Changes from all commits
36 commits
6301476
support only sync lora weight
hjh0119 Sep 8, 2025
c7be012
fix wip
hjh0119 Sep 8, 2025
22042fc
wip
hjh0119 Sep 8, 2025
1081caa
fix colocate lora
hjh0119 Sep 9, 2025
4c04d36
add lora for server wip
hjh0119 Sep 9, 2025
1982e9e
Merge branch 'lora+' of github.com:hjh0119/swift into lora+
hjh0119 Sep 9, 2025
5fe3690
fix import
hjh0119 Sep 9, 2025
161dac8
update extension path
hjh0119 Sep 9, 2025
0a14d20
override enable_lora for rollout
hjh0119 Sep 9, 2025
4574665
Merge branch 'lora+' of github.com:hjh0119/swift into lora+
hjh0119 Sep 9, 2025
efae3b2
catch rollout exception
hjh0119 Sep 9, 2025
f454598
fix lora request
hjh0119 Sep 9, 2025
d46bc1f
server wip
hjh0119 Sep 10, 2025
986ac8d
server add_lora wip
hjh0119 Sep 11, 2025
0dc8c6e
fix server tp
hjh0119 Sep 12, 2025
ba284ba
merge main
hjh0119 Sep 8, 2025
0f7ca2a
Merge branch 'lora+' of github.com:hjh0119/swift into lora+
hjh0119 Sep 12, 2025
849696f
doc wip
hjh0119 Sep 12, 2025
f70e827
doc
hjh0119 Sep 12, 2025
274f6db
check lora
hjh0119 Sep 12, 2025
f0b4de8
support only sync lora weight
hjh0119 Sep 12, 2025
6069888
add args for lora script
hjh0119 Sep 12, 2025
0cf5a62
update script
hjh0119 Sep 12, 2025
691a5df
fix
hjh0119 Sep 12, 2025
5cab78d
remove unused import
hjh0119 Sep 12, 2025
e43a0da
fix
hjh0119 Sep 12, 2025
688bf64
fix typo
hjh0119 Sep 12, 2025
4fa2d2f
fix unmerge
hjh0119 Sep 12, 2025
78f9473
wip
hjh0119 Sep 28, 2025
a50f756
Merge branch 'lora+' of github.com:hjh0119/swift into lora+
hjh0119 Sep 28, 2025
be00cf4
Merge remote-tracking branch 'origin' into lora+
hjh0119 Oct 9, 2025
6745666
bucket for full training in server mode
hjh0119 Oct 9, 2025
dfccf15
remove circle import
hjh0119 Oct 10, 2025
4652f77
fix TokenizerGroup removed in vllm 0.11.0
hjh0119 Oct 10, 2025
8c73590
rm comments
hjh0119 Oct 10, 2025
4389268
Merge branch 'lora+' of github.com:hjh0119/swift into lora+
hjh0119 Oct 10, 2025
6 changes: 5 additions & 1 deletion docs/source/BestPractices/GRPO代码训练.md
@@ -42,7 +42,9 @@
```bash
CUDA_VISIBLE_DEVICES=7 \
swift rollout \
--model Qwen/Qwen2.5-7B-Instruct
--model Qwen/Qwen2.5-7B-Instruct \
--vllm_enable_lora true \
--vllm_max_lora_rank 16
```

```bash
@@ -61,6 +63,8 @@ swift rlhf \
--vllm_server_host 127.0.0.1 \
--vllm_server_port 8000 \
--train_type lora \
--lora_rank 16 \
--lora_alpha 32 \
--torch_dtype bfloat16 \
--dataset 'open-r1/verifiable-coding-problems-python-10k' \
--load_from_cache_file true \
13 changes: 12 additions & 1 deletion docs/source/Instruction/GRPO/GetStarted/GRPO.md
@@ -185,7 +185,7 @@ swift rollout \

For more rollout parameters, see [vLLM arguments](../../../Instruction/命令行参数.md#vllm参数) and [rollout arguments](../../../Instruction/命令行参数.md#rollout参数)

Note: when use_async_engine is enabled, turning on only DP may cause errors; see this [vllm issue](https://github.com/vllm-project/vllm/issues/18567). If errors occur, try enabling both TP and DP
Note: when use_async_engine is enabled, turning on only DP may cause errors; see this [vllm issue](https://github.com/vllm-project/vllm/issues/18567). If errors occur, try enabling both TP and DP, or upgrade vLLM


During training, configure the external vLLM server with the following parameters
@@ -196,6 +196,17 @@
--vllm_server_port <service_port> \
--vllm_server_timeout <timeout> \
```
#### Weight-Sync Acceleration
Swift 3.9 optimizes weight synchronization for LoRA training (roughly a 10× speed-up over Swift 3.8).

To enable the optimized LoRA weight sync, add the following arguments to the rollout command:
```bash
--vllm_enable_lora true
--vllm_max_lora_rank xxx # set to the same value as lora_rank in the training script
```
Note: for multimodal model training, vLLM supports loading adapters only for the language-model part of the model. If you need to train the ViT layers of a multimodal model (freeze_vit false), set `vllm_enable_lora false`.

For implementation details of the optimization, see this [PR](https://github.com/modelscope/ms-swift/pull/5773)

## logged metrics
- completions/mean_length: the average length of generated completions.
2 changes: 2 additions & 0 deletions docs/source/Instruction/命令行参数.md
@@ -593,6 +593,8 @@ soft overlong reward parameters
- Rollout parameters
- multi_turn_scheduler: multi-turn GRPO parameter; pass the corresponding plugin name and add the matching implementation in plugin/multi_turn.py.
- max_turns: maximum number of turns for multi-turn GRPO. Defaults to None (no limit).
- vllm_enable_lora: allow the vLLM Engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training; see the [documentation](./GRPO/GetStarted/GRPO.md#权重同步加速) for details.
- vllm_max_lora_rank: LoRA parameter for the vLLM Engine. Must be greater than or equal to the training lora_rank; setting them equal is recommended. Defaults to 16.

### Rollout Arguments
Rollout arguments inherit from the [deployment arguments](#部署参数)
6 changes: 5 additions & 1 deletion docs/source_en/BestPractices/GRPO-Code-Training.md
@@ -46,7 +46,9 @@ launch external vLLM server using following script
```bash
CUDA_VISIBLE_DEVICES=7 \
swift rollout \
--model Qwen/Qwen2.5-7B-Instruct
--model Qwen/Qwen2.5-7B-Instruct \
--vllm_enable_lora true \
--vllm_max_lora_rank 16
```

```bash
@@ -65,6 +67,8 @@ swift rlhf \
--vllm_server_host 127.0.0.1 \
--vllm_server_port 8000 \
--train_type lora \
--lora_rank 16 \
--lora_alpha 32 \
--torch_dtype bfloat16 \
--dataset 'open-r1/verifiable-coding-problems-python-10k' \
--load_from_cache_file true \
2 changes: 2 additions & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -612,6 +612,8 @@ Deployment Arguments inherit from the [inference arguments](#inference-arguments
- Rollout Parameters
- multi_turn_scheduler: Multi-turn GRPO parameter; pass the corresponding plugin name, and make sure to implement it in plugin/multi_turn.py.
- max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit.
- vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details.
- vllm_max_lora_rank: LoRA parameter for the vLLM engine. Must be greater than or equal to the training lora_rank; it is recommended to set them equal. Defaults to 16.

### Rollout Arguments
The rollout parameters inherit from the [deployment parameters](#deployment-arguments).
14 changes: 14 additions & 0 deletions docs/source_en/Instruction/GRPO/GetStarted/GRPO.md
@@ -194,6 +194,20 @@ To configure the external vLLM server during training, use the following paramet
--vllm_server_port <service_port> \
--vllm_server_timeout <timeout> \
```

#### Weight-Sync Acceleration
Swift 3.9 optimizes weight synchronization for LoRA training, achieving a roughly 10× speed-up over Swift 3.8.

To enable the optimized LoRA weight sync, add the following arguments to your rollout command:

```bash
--vllm_enable_lora true
--vllm_max_lora_rank xxx # set to the same value as lora_rank in the training script
```
Note: For multimodal model training, vLLM supports loading adapters only for the language-model part. If you need to train the ViT layers of a multimodal model (freeze_vit false), set `vllm_enable_lora false`.

For implementation details, please refer to the [PR](https://github.com/modelscope/ms-swift/pull/5773).
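To see why syncing only the adapter is so much cheaper than a full-weight sync, here is a rough back-of-envelope sketch; the dimensions and the set of adapted modules are illustrative assumptions for a 7B-class model, not numbers from the PR:

```python
# Rough, illustrative estimate of how small a LoRA adapter is relative to the
# full model. All dims below are assumptions, not measured values.
hidden, layers, rank = 4096, 28, 16
full_params = 7e9  # ~7B parameters moved by a full-weight sync

# Each adapted projection adds two low-rank factors: (out x r) and (r x in).
per_module = 2 * hidden * rank
modules_per_layer = 4  # assume adapters on the q/k/v/o projections
lora_params = layers * modules_per_layer * per_module

print(f'adapter is {lora_params / full_params:.4%} of the full weights')
# -> well under 1% of the bytes, which is where the weight-sync speed-up comes from
```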

## logged metrics
- completions/mean_length: The average length of generated completions.
- completions/min_length: The minimum length among generated completions.
6 changes: 6 additions & 0 deletions examples/train/grpo/external/README.md
@@ -7,6 +7,12 @@
1. vLLM version 0.8.3 or higher.
2. trl version 0.17.0 or higher

For LoRA training, set the following parameters to speed up weight updates:
```bash
--vllm_enable_lora true
--vllm_max_lora_rank xxx # same as lora_rank in training script
```

## **Introduction**

The GRPO (Group Relative Policy Optimization) training framework supports high-performance inference engines like vLLM to accelerate the sampling process. The **External Mode** allows you to connect to an external vLLM inference server, separating the inference service from the training process. This mode is ideal for scenarios where you want to offload inference to dedicated hardware or servers, improving resource utilization and scalability.
52 changes: 52 additions & 0 deletions examples/train/grpo/external/mllm_lora.sh
@@ -0,0 +1,52 @@
# For LoRA training, set the following parameters to speed up weight updates
# ```bash
# --vllm_enable_lora true
# --vllm_max_lora_rank xxx # same as lora_rank in training script
# ```

# CUDA_VISIBLE_DEVICES=4,5,6,7 \
# swift rollout \
# --model Qwen/Qwen2.5-VL-7B-Instruct \
# --vllm_data_parallel_size 2 \
# --vllm_tensor_parallel_size 2 \
# --vllm_enable_lora true \
# --vllm_max_lora_rank 16


CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=4 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-VL-7B-Instruct \
--external_plugins examples/train/grpo/plugin/plugin.py \
--reward_funcs external_r1v_acc format \
--use_vllm true \
--vllm_mode server \
--vllm_server_host 127.0.0.1 \
--vllm_server_port 8000 \
--train_type lora \
--lora_rank 16 \
--lora_alpha 32 \
--torch_dtype bfloat16 \
--dataset 'AI-ModelScope/clevr_cogen_a_train' \
--max_completion_length 1024 \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 2 \
--save_strategy 'steps' \
--eval_strategy 'steps' \
--eval_steps 1000 \
--save_steps 1000 \
--save_total_limit 10 \
--logging_steps 1 \
--warmup_ratio 0.01 \
--dataloader_num_workers 4 \
--num_generations 16 \
--temperature 1.0 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero3 \
--log_completions true \
--report_to tensorboard swanlab \
--num_iterations 1 \
--beta 0.001
3 changes: 2 additions & 1 deletion swift/llm/argument/deploy_args.py
@@ -86,7 +86,8 @@ class RolloutArguments(DeployArguments):
# only for GRPO rollout with AsyncEngine, see details in swift/plugin/multi_turn
multi_turn_scheduler: Optional[str] = None
max_turns: Optional[int] = None

vllm_enable_lora: bool = False
vllm_max_lora_rank: int = 16
# GYM env
gym_env: Optional[str] = None
context_manager: Optional[str] = None
21 changes: 21 additions & 0 deletions swift/llm/infer/infer_engine/grpo_vllm_engine.py
@@ -14,6 +14,7 @@
try:
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '86400'
from vllm.lora.request import LoRARequest
except Exception:
raise

@@ -98,6 +99,16 @@ def infer(
use_tqdm: Optional[bool] = None,
adapter_request: Optional[AdapterRequest] = None,
) -> List[RolloutOutput]:
if not adapter_request and self.enable_lora:
lora_int_ids = list(self.engine.list_loras())
if lora_int_ids:
# since max_loras = 1, pick the first lora
adapter_request = LoRARequest(
lora_name=f'lora_{lora_int_ids[0]}',
lora_int_id=lora_int_ids[0],
lora_path='dummy_lora_path',
)

res = super().infer(
infer_requests,
request_config,
@@ -189,3 +200,13 @@ def _create_chat_completion_response(self, result, inputs, template: Template, r
id=request_id,
prompt_token_ids=prompt_token_ids,
images_size=images_size)

def _add_adapter(self, adapter_request: Optional[Union[AdapterRequest, LoRARequest]] = None):
assert self.enable_lora, f'adapter_request: {adapter_request}, self.enable_lora: {self.enable_lora}'
from vllm.lora.request import LoRARequest
if isinstance(adapter_request, AdapterRequest):
return super()._add_adapter(adapter_request)
elif isinstance(adapter_request, LoRARequest):
return adapter_request
else:
raise ValueError(f'Invalid adapter request: {adapter_request}')
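Taken together, the two hunks above form the rollout path for an adapter that is already resident in vLLM: `infer` wraps the single loaded LoRA id in a `LoRARequest`, and `_add_adapter` passes that request straight through instead of loading from disk. Below is a minimal sketch of the pick-first logic against a stub engine; everything except `LoRARequest` is a hypothetical stand-in, not swift's actual code:

```python
from vllm.lora.request import LoRARequest  # requires vllm installed

class StubEngine:
    """Stand-in for the vLLM engine; only tracks which adapter ids are loaded."""
    def __init__(self, loaded_ids):
        self._ids = list(loaded_ids)

    def list_loras(self):
        return self._ids

def resolve_adapter(engine, adapter_request=None):
    """Reuse the resident LoRA when the caller does not pass one explicitly."""
    if adapter_request is not None:
        return adapter_request
    ids = list(engine.list_loras())
    if not ids:
        return None  # nothing loaded yet: generate with the base weights
    # With max_loras=1 there is at most one id, so the first is the live adapter.
    return LoRARequest(lora_name=f'lora_{ids[0]}', lora_int_id=ids[0],
                       lora_path='dummy_lora_path')  # path unused for a resident LoRA

print(resolve_adapter(StubEngine([3])))  # LoRARequest for the resident adapter
print(resolve_adapter(StubEngine([])))   # None: fall back to the base model
```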
12 changes: 12 additions & 0 deletions swift/llm/infer/protocol.py
@@ -12,6 +12,8 @@
from PIL import Image
from pydantic import BaseModel, Field, field_validator

from swift.trainers.rlhf_trainer.utils import FlattenedTensorMetadata
from swift.tuners.lora import LoraConfig
from ..template import InferRequest
from ..utils import Messages, Tool

@@ -459,3 +461,13 @@ class UpdateWeightsRequest(BaseModel):
name: str
dtype: str
shape: list[int]


class UpdateFlattenedAdapterRequest(BaseModel):
lora_int_id: int
peft_config: LoraConfig
metadatas: List[FlattenedTensorMetadata]


class UpdateFlattenedParamsRequest(BaseModel):
metadatas: List[FlattenedTensorMetadata]
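The two request models above carry per-tensor metadata describing a single flattened buffer, so the trainer can push all adapter weights in one transfer instead of one request per tensor. A minimal sketch of the flattening idea follows; the metadata field names are assumptions for illustration, not the actual `FlattenedTensorMetadata` schema:

```python
import torch

def flatten_named_tensors(named_tensors):
    """Pack tensors into one contiguous buffer, recording where each one lives."""
    metadatas, chunks, offset = [], [], 0
    for name, tensor in named_tensors.items():
        flat = tensor.contiguous().view(-1)
        metadatas.append({
            'name': name,                 # parameter name, e.g. '...lora_A.weight'
            'shape': list(tensor.shape),  # original shape, to unflatten server-side
            'dtype': str(tensor.dtype),
            'offset': offset,             # start index inside the flat buffer
            'numel': flat.numel(),
        })
        chunks.append(flat)
        offset += flat.numel()
    return torch.cat(chunks), metadatas

# Example: one pair of LoRA factors becomes a single buffer plus its metadata.
buf, metas = flatten_named_tensors({
    'lora_A.weight': torch.randn(16, 4096),
    'lora_B.weight': torch.randn(4096, 16),
})
assert buf.numel() == sum(m['numel'] for m in metas)
```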