diff --git "a/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" "b/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" index 67f12e0165..ac00f9d180 100644 --- "a/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" +++ "b/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" @@ -42,7 +42,9 @@ ```bash CUDA_VISIBLE_DEVICES=7 \ swift rollout \ - --model Qwen/Qwen2.5-7B-Instruct + --model Qwen/Qwen2.5-7B-Instruct \ + --vllm_enable_lora true \ + --vllm_max_lora_rank 16 ``` ```bash @@ -61,6 +63,8 @@ swift rlhf \ --vllm_server_host 127.0.0.1 \ --vllm_server_port 8000 \ --train_type lora \ + --lora_rank 16 \ + --lora_alpha 32 \ --torch_dtype bfloat16 \ --dataset 'open-r1/verifiable-coding-problems-python-10k' \ --load_from_cache_file true \ diff --git a/docs/source/Instruction/GRPO/GetStarted/GRPO.md b/docs/source/Instruction/GRPO/GetStarted/GRPO.md index f0bb51bdf4..d2a8dd16ae 100644 --- a/docs/source/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source/Instruction/GRPO/GetStarted/GRPO.md @@ -185,7 +185,7 @@ swift rollout \ 更多 rollout 参数参考[vLLM参数](../../../Instruction/命令行参数.md#vllm参数)和[rollout 参数](../../../Instruction/命令行参数.md#rollout参数) -注意:在使用 use_async_engine 时,仅开启 DP 可能会导致错误,相关问题参考: [vllm issue](https://github.com/vllm-project/vllm/issues/18567)。如果出现错误,请尝试同时启用 TP 和 DP。 +注意:在使用 use_async_engine 时,仅开启 DP 可能会导致错误,相关问题参考: [vllm issue](https://github.com/vllm-project/vllm/issues/18567)。如果出现错误,请尝试同时启用 TP 和 DP,或升级vLLM 训练使用以下参数配置外部 vLLM 服务器 @@ -196,6 +196,17 @@ swift rollout \ --vllm_server_port <服务端口> \ --vllm_server_timeout <超时时间> \ ``` +#### 权重同步加速 +swift 3.9对 LoRA 训练的权重同步进行了优化(相比swift3.8加速约10倍) + +为开启LoRA权重同步优化,请在rollout命令中设置以下参数 +```bash + --vllm_enable_lora true + --vllm_max_lora_rank xxx # 与训练脚本lora_rank一致 +``` +注意:对于多模态模型训练,vLLM 仅支持多模态模型的语言模型部分的adapter加载,如果需要训练多模态模型的ViT层(freeze_vit false),请设置`vllm_enable_lora false` + +优化实现细节请参考该[PR](https://github.com/modelscope/ms-swift/pull/5773) ## logged metrics - completions/mean_length:生成的 completion 的平均长度。 diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 7917eddce2..66afff6745 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -606,6 +606,8 @@ soft overlong 奖励参数 Rollout参数继承于[部署参数](#部署参数) - multi_turn_scheduler: 多轮GRPO训练规划器,传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。默认为None,具体参考[文档](./GRPO/DeveloperGuide/多轮训练.md) - max_turns: 多轮GRPO训练下的最大轮数,默认为None,即不做约束。 +- vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速) +- vllm_max_lora_rank: vLLM Engine LoRA参数,需大于等于训练的lora_rank,建议等于。默认为16。 ### Web-UI参数 - server_name: web-ui的host,默认为'0.0.0.0'。 diff --git a/docs/source_en/BestPractices/GRPO-Code-Training.md b/docs/source_en/BestPractices/GRPO-Code-Training.md index aa24f56985..3b822ec97c 100644 --- a/docs/source_en/BestPractices/GRPO-Code-Training.md +++ b/docs/source_en/BestPractices/GRPO-Code-Training.md @@ -46,7 +46,9 @@ launch external vLLM server using following script ```bash CUDA_VISIBLE_DEVICES=7 \ swift rollout \ - --model Qwen/Qwen2.5-7B-Instruct + --model Qwen/Qwen2.5-7B-Instruct \ + --vllm_enable_lora true \ + --vllm_max_lora_rank 16 ``` ```bash @@ -65,6 +67,8 @@ 
swift rlhf \ --vllm_server_host 127.0.0.1 \ --vllm_server_port 8000 \ --train_type lora \ + --lora_rank 16 \ + --lora_alpha 32 \ --torch_dtype bfloat16 \ --dataset 'open-r1/verifiable-coding-problems-python-10k' \ --load_from_cache_file true \ diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index ead203be2d..d609b21066 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -623,6 +623,8 @@ Deployment Arguments inherit from the [inference arguments](#inference-arguments The rollout parameters inherit from the [deployment parameters](#deployment-arguments). - multi_turn_scheduler: The scheduler for multi-turn GRPO training. Pass the corresponding plugin name, and ensure the implementation is added in `plugin/multi_turn.py`. Default is `None`. See [documentation](./GRPO/DeveloperGuide/multi_turn.md) for details. - max_turns: Maximum number of turns in multi-turn GRPO training. Default is `None`, meaning no limit. +- vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details. +- vllm_max_lora_rank: LoRA parameter for the vLLM engine. Must be greater than or equal to the training lora_rank; it is recommended to set them equal. Defaults to 16. ### Web-UI Arguments - server_name: Host for the web UI, default is '0.0.0.0'. diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md index 73b4772005..80a251e870 100644 --- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md @@ -194,6 +194,20 @@ To configure the external vLLM server during training, use the following paramet --vllm_server_port \ --vllm_server_timeout \ ``` + +#### Weight-Sync Acceleration +Swift 3.9 optimizes weight synchronization for LoRA training, achieving ~10× speed-up over Swift 3.8. + +To enable the optimized LoRA weight sync, add the following arguments to your rollout command: + +```bash + --vllm_enable_lora true + --vllm_max_lora_rank xxx # set to the same value as lora_rank in the training script +``` +Note: For multimodal model training, vLLM supports loading adapters only for the language-model part. If you need to train the ViT layers of a multimodal model (freeze_vit false), set `vllm_enable_lora false`. + +For implementation details, please refer to the [PR](https://github.com/modelscope/ms-swift/pull/5773) + ## logged metrics - completions/mean_length: The average length of generated completions. - completions/min_length: The minimum length among generated completions. diff --git a/examples/train/grpo/external/README.md b/examples/train/grpo/external/README.md index 733199dd4c..a4808a5c9d 100644 --- a/examples/train/grpo/external/README.md +++ b/examples/train/grpo/external/README.md @@ -7,6 +7,12 @@ 1. vLLM version 0.8.3 or higher. 2. trl version 0.17.0 or higher +For LoRA Training, set following parameters to speed up weight update +```bash + --vllm_enable_lora true + --vllm_max_lora_rank xxx # same as lora_rank in training script +``` + ## **Introduction** The GRPO (Group Relative Policy Optimization) training framework supports high-performance inference engines like vLLM to accelerate the sampling process. 
The **External Mode** allows you to connect to an external vLLM inference server, separating the inference service from the training process. This mode is ideal for scenarios where you want to offload inference to dedicated hardware or servers, improving resource utilization and scalability. diff --git a/examples/train/grpo/external/mllm_lora.sh b/examples/train/grpo/external/mllm_lora.sh new file mode 100644 index 0000000000..e2b4e38bf3 --- /dev/null +++ b/examples/train/grpo/external/mllm_lora.sh @@ -0,0 +1,52 @@ +# For LoRA Training, set following parameters to speed up weight update +# ```bash +# --vllm_enable_lora true +# --vllm_max_lora_rank xxx # same as lora_rank in training script +# ``` + +# CUDA_VISIBLE_DEVICES=4,5,6,7 \ +# swift rollout \ +# --model Qwen/Qwen2.5-VL-7B-Instruct \ +# --vllm_data_parallel_size 2 \ +# --vllm_tensor_parallel_size 2 \ +# --vllm_enable_lora true \ +# --vllm_max_lora_rank 16 + + +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +NPROC_PER_NODE=4 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --external_plugins examples/train/grpo/plugin/plugin.py \ + --reward_funcs external_r1v_acc format \ + --use_vllm true \ + --vllm_mode server \ + --vllm_server_host 127.0.0.1 \ + --vllm_server_port 8000 \ + --train_type lora \ + --lora_rank 16 \ + --lora_alpha 32 \ + --torch_dtype bfloat16 \ + --dataset 'AI-ModelScope/clevr_cogen_a_train' \ + --max_completion_length 1024 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 2 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 16 \ + --temperature 1.0 \ + --system 'examples/train/grpo/prompt.txt' \ + --deepspeed zero3 \ + --log_completions true \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 diff --git a/swift/llm/argument/deploy_args.py b/swift/llm/argument/deploy_args.py index c6a762bec0..71ae22b2bd 100644 --- a/swift/llm/argument/deploy_args.py +++ b/swift/llm/argument/deploy_args.py @@ -86,7 +86,8 @@ class RolloutArguments(DeployArguments): # only for GRPO rollout with AsyncEngine, see details in swift/plugin/multi_turn multi_turn_scheduler: Optional[str] = None max_turns: Optional[int] = None - + vllm_enable_lora: bool = False + vllm_max_lora_rank: int = 16 # GYM env gym_env: Optional[str] = None context_manager: Optional[str] = None diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index ac546f4b0d..6274f4d0ca 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -14,6 +14,7 @@ try: os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '86400' + from vllm.lora.request import LoRARequest except Exception: raise @@ -98,6 +99,16 @@ def infer( use_tqdm: Optional[bool] = None, adapter_request: Optional[AdapterRequest] = None, ) -> List[RolloutOutput]: + if not adapter_request and self.enable_lora: + lora_int_ids = list(self.engine.list_loras()) + if lora_int_ids: + # since max_lora = 1, pick the first lora + adapter_request = LoRARequest( + lora_name=f'lora_{lora_int_ids[0]}', + lora_int_id=lora_int_ids[0], + lora_path='dummy_lora_path', + ) + res = super().infer( infer_requests, request_config, @@ -189,3 +200,13 @@ def _create_chat_completion_response(self, result, 
inputs, template: Template, r id=request_id, prompt_token_ids=prompt_token_ids, images_size=images_size) + + def _add_adapter(self, adapter_request: Optional[Union[AdapterRequest, LoRARequest]] = None): + assert self.enable_lora, f'adapter_request: {adapter_request}, self.enable_lora: {self.enable_lora}' + from vllm.lora.request import LoRARequest + if isinstance(adapter_request, AdapterRequest): + return super()._add_adapter(adapter_request) + elif isinstance(adapter_request, LoRARequest): + return adapter_request + else: + raise ValueError(f'Invalid adapter request: {adapter_request}') diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index e2e0d02783..a3926df386 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -12,6 +12,8 @@ from PIL import Image from pydantic import BaseModel, Field, field_validator +from swift.trainers.rlhf_trainer.utils import FlattenedTensorMetadata +from swift.tuners.lora import LoraConfig from ..template import InferRequest from ..utils import Messages, Tool @@ -459,3 +461,13 @@ class UpdateWeightsRequest(BaseModel): name: str dtype: str shape: list[int] + + +class UpdateFlattenedAdapterRequest(BaseModel): + lora_int_id: int + peft_config: LoraConfig + metadatas: List[FlattenedTensorMetadata] + + +class UpdateFlattenedParamsRequest(BaseModel): + metadatas: List[FlattenedTensorMetadata] diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index d56c18e301..205220b682 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -6,26 +6,31 @@ import multiprocessing import os import time +import traceback +from collections.abc import Sequence from contextlib import asynccontextmanager, contextmanager from dataclasses import asdict from functools import wraps from itertools import chain from multiprocessing import Pipe, Process from multiprocessing.connection import Connection -from typing import Dict, List, Optional, Union, get_type_hints +from typing import Dict, List, Optional, Union import torch import uvicorn from aiohttp import ClientConnectorError from fastapi import FastAPI -from trl.scripts.vllm_serve import WeightSyncWorkerExtension +from trl.scripts.vllm_serve import WeightSyncWorkerExtension as HFWeightSyncWorkerExtension from swift.llm import RolloutArguments, SwiftPipeline from swift.llm.template.template_inputs import RolloutInferRequest from swift.plugin.multi_turn import RolloutScheduler, multi_turns +from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, TensorLoRARequest, + patch_vllm_load_adapter) from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient -from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest +from .protocol import (InitCommunicatorRequest, RequestConfig, UpdateFlattenedAdapterRequest, + UpdateFlattenedParamsRequest, UpdateWeightsRequest) try: from vllm.utils import get_open_port @@ -50,6 +55,84 @@ - For inference or deployment, please use the `swift infer` or `swift deploy` commands. """ +patch_vllm_load_adapter() + + +class WeightSyncWorkerExtension(HFWeightSyncWorkerExtension): + + def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None: + """ + Receives updated weights from the client process and updates the named parameter in the model. + + Args: + name (`str`): + Name of the weight tensor being updated. + dtype (`str`): + Data type of the weight tensor as a string (e.g., `"torch.float32"`). 
+            shape (`Sequence[int]`):
+                Shape of the weight tensor.
+        """
+        if self.pynccl_comm is None:
+            raise RuntimeError('Communicator not initialized. Call `init_communicator` first.')
+
+        dtype = getattr(torch, dtype.split('.')[-1])
+        # Allocate memory for the incoming weight tensor on the correct device.
+        weight = torch.empty(shape, dtype=dtype, device=self.device)
+
+        # Use NCCL to broadcast the updated weights from the client (src) to all workers.
+        self.pynccl_comm.broadcast(weight, src=self.client_rank)
+        self.pynccl_comm.group.barrier()
+
+        # Load the received weights into the model.
+        self.model_runner.model.load_weights(weights=[(name, weight)])
+
+    def update_adapter_flattened_param(self, lora_int_id: int, peft_config: Dict, metadatas: list[Dict]) -> None:
+        """
+        Receives a flattened bucket of LoRA adapter weights from the client process and registers it
+        as a LoRA adapter on the vLLM engine.
+
+        Args:
+            lora_int_id (`int`): Integer id used to register the adapter with vLLM.
+            peft_config (`Dict`): PEFT LoRA configuration of the adapter.
+            metadatas (`list[Dict]`): Metadata describing each tensor in the flattened bucket.
+        """
+        metadatas = [FlattenedTensorMetadata(**metadata) for metadata in metadatas]
+        if self.pynccl_comm is None:
+            raise RuntimeError('Communicator not initialized. Call `init_communicator` first.')
+        flatten_tensor_length = metadatas[-1].end_idx
+        dtype = getattr(torch, metadatas[-1].dtype.split('.')[-1])
+        flatten_tensor = torch.empty(flatten_tensor_length, dtype=dtype, device=self.device)
+        self.pynccl_comm.broadcast(flatten_tensor, src=self.client_rank)
+        self.pynccl_comm.group.barrier()
+        flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor)
+        named_params = flattened_tensor_bucket.reconstruct_tensors()
+        lora_request = TensorLoRARequest(
+            lora_name=f'{lora_int_id}',
+            lora_int_id=lora_int_id,
+            lora_path='dummy_lora_path',
+            peft_config=peft_config,
+            lora_tensors=named_params)
+        self.add_lora(lora_request)
+
+    def update_flattened_params(self, metadatas: list[Dict]) -> None:
+        """
+        Receives updated flattened weights from the client process and updates the model parameters.
+
+        Args:
+            metadatas (list[Dict]): List of metadata dictionaries for the flattened tensors.
+        """
+        metadatas = [FlattenedTensorMetadata(**metadata) for metadata in metadatas]
+        if self.pynccl_comm is None:
+            raise RuntimeError('Communicator not initialized. 
Call `init_communicator` first.') + + flatten_tensor_length = metadatas[-1].end_idx + dtype = getattr(torch, metadatas[-1].dtype.split('.')[-1]) + flatten_tensor = torch.empty(flatten_tensor_length, dtype=dtype, device=self.device) + + self.pynccl_comm.broadcast(flatten_tensor, src=self.client_rank) + self.pynccl_comm.group.barrier() + + flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor) + named_params = flattened_tensor_bucket.reconstruct_tensors() + + # Load the reconstructed parameters into the model + self.model_runner.model.load_weights(weights=list(named_params.items())) + + logger = get_logger() @@ -109,7 +192,11 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) - result = method(*args, **kwargs) + try: + result = method(*args, **kwargs) + except Exception: + logger.error(f'Method execution failed: {method_name}\n{traceback.format_exc()}') + result = None if command['type'] == 'call': connection.send(result) elif command['type'] == 'shutdown': @@ -138,7 +225,6 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast # Handle commands if command['type'] in ['call', 'fire_and_forget']: - import traceback method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) @@ -167,6 +253,8 @@ def _register_rl_rollout_app(self): self.app.get('/get_world_size/')(self.get_world_size) self.app.post('/init_communicator/')(self.init_communicator) self.app.post('/update_named_param/')(self.update_named_param) + self.app.post('/update_adapter_flattened_param/')(self.update_adapter_flattened_param) + self.app.post('/update_flattened_params/')(self.update_flattened_params) self.app.post('/reset_prefix_cache/')(self.reset_prefix_cache) self.app.post('/close_communicator/')(self.close_communicator) self.app.post('/infer/', response_model=None)(self.infer) @@ -224,16 +312,18 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): 'torch_dtype': args.torch_dtype, 'template': template, 'use_async_engine': args.vllm_use_async_engine, + 'max_lora_rank': args.vllm_max_lora_rank, }) infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend if infer_backend != 'vllm': infer_backend = 'vllm' logger.info('Currently, rollout only supports the vLLM backend. 
Set vLLM backend') kwargs.update(args.get_vllm_engine_kwargs()) + kwargs.update({'enable_lora': args.vllm_enable_lora}) # override # used for RL external rollout backend engine_kwargs = kwargs.get('engine_kwargs', {}) # for RL rollout model weight sync - engine_kwargs.update({'worker_extension_cls': 'trl.scripts.vllm_serve.WeightSyncWorkerExtension'}) + engine_kwargs.update({'worker_extension_cls': 'swift.llm.infer.rollout.WeightSyncWorkerExtension'}) engine_kwargs['load_format'] = 'dummy' if args.vllm_use_async_engine and args.vllm_data_parallel_size > 1: engine_kwargs['data_parallel_size'] = args.vllm_data_parallel_size @@ -311,6 +401,37 @@ async def update_named_param(self, request: UpdateWeightsRequest): return {'message': 'Request received, updating named parameter'} + async def update_adapter_flattened_param(self, request: UpdateFlattenedAdapterRequest): + peft_config = asdict(request.peft_config) + metadatas = [ + metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict() + for metadata in request.metadatas + ] + kwargs = {'method': 'update_adapter_flattened_param', 'args': (request.lora_int_id, peft_config, metadatas)} + for connection in self.connections: + connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs}) + + return {'message': 'Request received, updating adapter parameter'} + + async def update_flattened_params(self, request: UpdateFlattenedParamsRequest): + """ + Updates the model weights with flattened tensor data. + + Args: + request (UpdateFlattenedParamsRequest): + - metadatas (List[FlattenedTensorMetadata]): Metadata for the flattened tensors. + + """ + metadatas = [ + metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict() + for metadata in request.metadatas + ] + kwargs = {'method': 'update_flattened_params', 'args': (metadatas, )} + for connection in self.connections: + connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs}) + + return {'message': 'Request received, updating flattened parameters'} + async def reset_prefix_cache(self): """ Resets the prefix cache for the model. 
@@ -342,13 +463,17 @@ async def get_engine_type(self): enable_multi_turn = False if self.args.multi_turn_scheduler: enable_multi_turn = True - - if self.use_async_engine: - if self.use_gym_env: - return {'engine_type': 'AsyncLLMEngine', 'gym_env': True, 'enable_multi_turn': True} - return {'engine_type': 'AsyncLLMEngine', 'enable_multi_turn': enable_multi_turn} - else: - return {'engine_type': 'LLMEngine', 'enable_multi_turn': enable_multi_turn} + use_gym_env = False + if self.use_async_engine and self.use_gym_env: + use_gym_env = True + engine_type = 'AsyncLLMEngine' if self.use_async_engine else 'LLMEngine' + enable_lora = self.args.vllm_enable_lora + return { + 'engine_type': engine_type, + 'enable_multi_turn': enable_multi_turn, + 'use_gym_env': use_gym_env, + 'enable_lora': enable_lora, + } async def close_communicator(self): """ @@ -429,22 +554,3 @@ def run_rollout(args: RolloutArguments, return_url: bool = False): finally: process.terminate() logger.info('The deployment process has been terminated.') - - -# https://github.com/huggingface/trl/pull/3690 -# This patch handles backward compatibility for dtype parameter type changes in TRL: -# - For TRL <= 0.19: dtype_annotation is torch.dtype (needs patching) -# - For TRL > 0.19: dtype_annotation is str (no patching needed) -old_update_named_param = WeightSyncWorkerExtension.update_named_param -dtype_annotation = get_type_hints(old_update_named_param).get('dtype') - -if not hasattr(WeightSyncWorkerExtension, 'old_update_named_param') and dtype_annotation == torch.dtype: - - @wraps(old_update_named_param) - def patched_update_named_param(self, name, dtype, shape) -> None: - if isinstance(dtype, str): - dtype = getattr(torch, dtype.split('.')[-1]) - return old_update_named_param(self, name, dtype, shape) - - WeightSyncWorkerExtension.update_named_param = patched_update_named_param - WeightSyncWorkerExtension.old_update_named_param = old_update_named_param diff --git a/swift/trainers/rlhf_arguments.py b/swift/trainers/rlhf_arguments.py index b8423dc340..89fb020e08 100644 --- a/swift/trainers/rlhf_arguments.py +++ b/swift/trainers/rlhf_arguments.py @@ -51,6 +51,7 @@ class GKDConfig(SwiftArgumentsMixin, HfGKDConfig): @dataclass class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig): stop_words: List[str] = field(default_factory=list) + lora_rank: int = 8 # for vllm lora adapter def __post_init__(self): GRPOArgumentsMixin.__post_init__(self) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index f53004e536..c2e32e0e85 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -7,7 +7,7 @@ import re import time import uuid -from collections import defaultdict, deque +from collections import OrderedDict, defaultdict, deque from concurrent.futures import Future from contextlib import contextmanager, nullcontext from copy import copy, deepcopy @@ -25,6 +25,7 @@ from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from dacite import from_dict from packaging import version +from peft.utils.save_and_load import get_peft_model_state_dict from torch.nn import ModuleList from torch.utils.data import DataLoader from transformers import PreTrainedModel, TrainerCallback @@ -48,10 +49,11 @@ unwrap_model_for_generation) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin -from .utils import (_ForwardRedirection, compute_chord_loss, identity_data_collator, load_pil_img, - 
make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, patch_profiling_context, - patch_profiling_decorator, patch_save_last_checkpoint, replace_assistant_response_with_ids, - set_expandable_segments) +from .utils import (FlattenedTensorBucket, TensorLoRARequest, _create_parameter_buckets, _ForwardRedirection, + _process_bucket_with_flattened_tensor, compute_chord_loss, get_gather_if_zero3_context, + identity_data_collator, load_pil_img, make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, + patch_profiling_context, patch_profiling_decorator, patch_save_last_checkpoint, + patch_vllm_load_adapter, replace_assistant_response_with_ids, set_expandable_segments) from .vllm_client import VLLMClient try: @@ -245,19 +247,20 @@ def __init__(self, # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but # it's safer to set it in all cases. set_seed(args.seed, device_specific=True) - if is_peft_model(self.model): - self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() self.use_fast_infer = self.use_vllm # whether to use the PT backend self.vllm_use_async_engine = False self.enable_offload = False self.use_gym_env = False self.enable_server_multi_turn = False + self.rollout_enable_lora = False # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs self.dynamic_num_samples = False if self.use_vllm: if not is_vllm_available(): raise ImportError('vLLM is not available and `use_vllm` is set to True. ' 'Please install vLLM with `pip install vllm -U` to use it.') + self.base_sync_done = False # tag for lora weights sync + if self.vllm_mode == 'server': self.vllm_client: VLLMClient = vllm_client if self.accelerator.is_main_process: @@ -265,13 +268,16 @@ def __init__(self, vllm_use_async_engine = [self.vllm_client.use_async_engine] use_gym_env = [self.vllm_client.use_gym_env] enable_multi_turn = [self.vllm_client.enable_multi_turn] + enable_lora = [self.vllm_client.enable_lora] else: vllm_use_async_engine = [False] use_gym_env = [False] enable_multi_turn = [self.enable_server_multi_turn] + enable_lora = [False] self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] + self.rollout_enable_lora = broadcast_object_list(enable_lora, from_process=0)[0] if self.use_gym_env: self.reward_func_names = ['gym_reward'] @@ -302,6 +308,8 @@ def __init__(self, infer_template.padding_free = False self.engine = PtEngine.from_model_template(self.model, infer_template, max_batch_size=0) # 0: no limit + self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() + if not self.reward_funcs and not self.use_gym_env: raise ValueError('You must specify reward_funcs or reward_model') @@ -486,7 +494,10 @@ def replace_lora(name): if 'lora_' in name: return '' else: - return name.replace('base_layer.', '') + if not self.rollout_enable_lora: + return re.sub(r'\.base_layer\.', '.', name) + else: + return name def remove_lora_and_prefix(names): names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names]) @@ -533,6 +544,15 @@ def prepare_vllm(self, model): self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation) vllm_template = copy(self.template) vllm_template.padding_free = False + lora_kwargs = {} + if self.args.train_type == 'lora': 
+ lora_kwargs = { + 'enable_lora': True, + 'max_loras': 1, + 'max_lora_rank': self.args.lora_rank, + } + self.rollout_enable_lora = True + patch_vllm_load_adapter() with Swift.grpo_context(model, self.template.processor): set_expandable_segments(False) engine = GRPOVllmEngine( @@ -553,6 +573,7 @@ def prepare_vllm(self, model): load_format='dummy', template=vllm_template, distributed_executor_backend='external_launcher', + **lora_kwargs, ) set_expandable_segments(True) return engine @@ -569,31 +590,75 @@ def _template_context(self, template: Template): @patch_profiling_decorator def _move_model_to_vllm(self, skip_async_check=False): - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - if self.args.async_generate and not skip_async_check: # before sync weight, we should wait async generate finish self._wait_queue() + train_type = self.args.train_type + + if train_type == 'full' or (train_type == 'lora' and not self.base_sync_done) or not self.rollout_enable_lora: + self._move_full_model_to_vllm() + else: + self._move_adapter_to_vllm() + + def _move_adapter_to_vllm(self): + lora_params = OrderedDict() + for i, parameter_group in enumerate(self.parameter_groups): # < this is the change + parameters = [ + parameter for name, parameter in self.model.named_parameters() + if not parameter_group or name in parameter_group + ] + gather_if_zero3 = get_gather_if_zero3_context(self) + with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): + assert len(parameters) == len(parameter_group) + state_dict = {name: p for p, name in zip(parameters, parameter_group)} + peft_config = self.model.peft_config.get('default', None) + self.model.merge_adapter() + cur_lora_params = get_peft_model_state_dict(self.model, state_dict) + cur_lora_params = { + name: param.full_tensor().detach() if hasattr(param, 'full_tensor') else param.detach() + for name, param in cur_lora_params.items() + } + lora_params.update(cur_lora_params) + with patch_lora_unmerge(self.model): + self.model.unmerge_adapter() + del cur_lora_params + + if self.vllm_mode == 'server' and self.accelerator.is_main_process: + bucked = FlattenedTensorBucket(named_tensors=list(lora_params.items())) + metadatas = bucked.get_metadata() + flattened_tensor = bucked.get_flattened_tensor() + self.vllm_client.update_adapter_flattened_param(peft_config, metadatas, flattened_tensor) + elif self.vllm_mode == 'colocate': + lora_int_id = int(time.time_ns() % 0x7FFFFFFF) + lora_reqest = TensorLoRARequest( + lora_name=f'{lora_int_id}', + lora_int_id=lora_int_id, + lora_path='dummy_lora_path', + peft_config=asdict(peft_config), + lora_tensors=lora_params, + ) + self.engine.llm_engine.add_lora(lora_reqest) + del lora_params + + def _move_full_model_to_vllm(self): + gather_if_zero3 = get_gather_if_zero3_context(self) if is_peft_model(self.model): - for i, parameter_group in enumerate(self.parameter_groups): # < this is the change + for i, parameter_group in enumerate(self.parameter_groups): parameter_group_no_lora = self.parameter_groups_no_lora[i] parameters = [ parameter for name, parameter in self.model.named_parameters() if not parameter_group or name in parameter_group ] with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): - self.model.merge_adapter() + if self.should_merge_adapter: + # if 
rollout enable lora, we will only execute once before the first rollout + self.model.merge_adapter() state_dict = self.model.state_dict() - state_dict = { - k.removeprefix('base_model.model.').replace('.base_layer', ''): v - for k, v in state_dict.items() + prefix_removed = {k.removeprefix('base_model.model.'): v for k, v in state_dict.items()} + state_dict = prefix_removed if self.rollout_enable_lora else { + k.replace('.base_layer', ''): v + for k, v in prefix_removed.items() } state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k} # When module to save, remove its prefix and discard the original module @@ -608,22 +673,68 @@ def _move_model_to_vllm(self, skip_async_check=False): [state.shape != torch.Size([0]) for state in state_dict.values()]) if self.vllm_mode == 'server' and self.accelerator.is_main_process: - for name, param in state_dict.items(): - self.vllm_client.update_named_param(name, param) + # Create parameter buckets and process them efficiently + named_params = list(state_dict.items()) + parameter_buckets = _create_parameter_buckets(named_params) + + # Process each bucket using flattened tensor approach + for bucket in parameter_buckets: + _process_bucket_with_flattened_tensor(self, bucket) + + del named_params, parameter_buckets elif self.vllm_mode == 'colocate': llm_model = self.engine.inner_model llm_model.load_weights(state_dict.items()) - with patch_lora_unmerge(self.model): - self.model.unmerge_adapter() + if self.should_merge_adapter: + with patch_lora_unmerge(self.model): + self.model.unmerge_adapter() del state_dict + self.base_sync_done = True else: - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - if self.vllm_mode == 'server' and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == 'colocate': - llm_model = self.engine.inner_model - llm_model.load_weights([(name, param.data)]) + if self.vllm_mode == 'server': + bucket_size_bytes = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) * 1024 * 1024 + for i, parameter_group in enumerate(self.parameter_groups): + parameter_group_no_lora = self.parameter_groups_no_lora[i] + parameters = [ + parameter for name, parameter in self.model.named_parameters() + if not parameter_group or name in parameter_group + ] + with gather_if_zero3(parameters): + if self.accelerator.is_main_process: + # Get state_dict AFTER gather to get full parameters + state_dict = self.model.state_dict() + + # Filter by parameter_group_no_lora if specified + if parameter_group_no_lora: + state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + + # Split gathered parameters into buckets + current_bucket = [] + current_size = 0 + + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() + + # If adding this param would exceed bucket size, process current bucket first + if current_size + param_size > bucket_size_bytes and current_bucket: + _process_bucket_with_flattened_tensor(self, current_bucket) + current_bucket = [] + current_size = 0 + + current_bucket.append((name, param)) + current_size += param_size + + # Process remaining parameters in the last bucket + if current_bucket: + _process_bucket_with_flattened_tensor(self, current_bucket) + + del state_dict + else: + for name, param in self.model.named_parameters(): + with gather_if_zero3([param]): + if self.vllm_mode == 'colocate': + llm_model = self.engine.inner_model + llm_model.load_weights([(name, 
param.data)]) if self.vllm_mode == 'server' and self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() @@ -2938,3 +3049,33 @@ def get_chunked_inputs(self, inputs, start_idx, end_idx): chunk_inputs.update(to_device(template.data_collator(encoded_data), self.model.device)) chunk_inputs.pop('labels', None) return chunk_inputs + + @property + def should_merge_adapter(self): + """ + Determine whether the LoRA adapter should be merged into the base model during weight synchronization. + + Note: + Merging or unmerging adapters in MoE models is computationally expensive and should be minimized. + + Raises: + AssertionError: If full-parameter training is used, as adapter merging is not supported. + + Returns: + bool: True if the adapter should be merged; False otherwise. + - Returns True when LoRA is not enabled for rollout. + - Returns True when loading from a checkpoint or using pre-trained adapters. + - Returns False during normal LoRA training (weights are already synchronized). + """ + assert self.args.train_type != 'full', 'Full-parameter training should not merge adapter' + + # Rollout does not support LoRA + if not self.rollout_enable_lora: + return True + + if self.args.resume_from_checkpoint: + # Resuming training: merge into base model + return True + + # base model weights are synced before training; no need to merge + return False diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index c168d69c3a..a17cb3f935 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -3,23 +3,27 @@ import math import os import time -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext +from dataclasses import asdict, dataclass from functools import partial from io import BytesIO from types import MethodType -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import datasets +import json import torch import torch.nn.functional as F +from msgspec import field from peft.tuners import lora from peft.tuners.lora import LoraLayer from PIL import Image +from pydantic import BaseModel, field_validator from torch import nn from torch.utils.data import DataLoader, RandomSampler from transformers import Trainer -from swift.utils import is_swanlab_available, is_wandb_available +from swift.utils import is_swanlab_available, is_vllm_available, is_wandb_available if is_wandb_available(): import wandb @@ -29,6 +33,23 @@ if TYPE_CHECKING: from swift.llm.utils import Messages +TensorLoRARequest = None +if is_vllm_available(): + from vllm.lora.request import LoRARequest + + class TensorLoRARequest(LoRARequest): + peft_config: dict = field(default=None) + lora_tensors: dict = field(default=None) + lora_embeddings: Optional[Dict[str, torch.Tensor]] = None + + @property + def config(self): + return self.peft_config + + @property + def embeddings(self): + return self.lora_embeddings + def round_robin(num_reqs, num_workers): """Distribute requests evenly across workers using round-robin algorithm. 
@@ -368,6 +389,231 @@ def patched_len(self) -> int: RepeatSampler.old_len_func = origin_len_func +def get_gather_if_zero3_context(trainer): + deepspeed_plugin = trainer.accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + return gather_if_zero3 + + +def patch_vllm_load_adapter(): + # from vllm.lora.worker_manager import WorkerLoRAManager + from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager + from vllm.lora.models import LoRAModel + from vllm.lora.utils import get_adapter_absolute_path + + try: + from vllm.transformers_utils.tokenizer_group import TokenizerGroup + except ImportError: + # removed in https://github.com/vllm-project/vllm/pull/24078 + TokenizerGroup = None + + def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLoRARequest) -> LoRAModel: + """ + code borrowed from verl.utils.vllm.utils.py + based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors + Reason: + VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths. + To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to + load memory-based LoRA tensors. + """ + try: + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping + expected_lora_modules: list[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend(packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + expected_lora_modules = list(set(expected_lora_modules)) + # this is the patch + lora_tensors = None + from vllm.lora.peft_helper import PEFTHelper + if isinstance(lora_request, TensorLoRARequest): + peft_config = lora_request.peft_config + lora_tensors = lora_request.lora_tensors + peft_helper = PEFTHelper.from_dict(peft_config) + else: + lora_path = get_adapter_absolute_path(lora_request.lora_path) + peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings) + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. 
+ model = self._adapter_manager.model + hf_to_vllm_mapper = getattr(model, 'hf_to_vllm_mapper', None) + if isinstance(lora_request, TensorLoRARequest): # this is the patch + lora = self._lora_model_cls.from_lora_tensors( + lora_model_id=lora_request.lora_int_id, + tensors=lora_tensors, + peft_helper=peft_helper, + device='cpu', + dtype=self.lora_config.lora_dtype, + embeddings=None, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + else: + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device='cpu', + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + except Exception as e: + raise e + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f'LoRA added vocab size {lora.extra_vocab_size} is greater than ' + f'lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}.') + return lora + + def patched_get_lora_tokenizer(self: TokenizerGroup, lora_request: LoRARequest): + # since we pass dummy path, skip get tokenizer from path + return self.tokenizer + + if not hasattr(LRUCacheWorkerLoRAManager, '_old_load_adapter'): + _old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter + LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter + LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter + if TokenizerGroup is not None: + TokenizerGroup._old_get_lora_tokenizer = TokenizerGroup.get_lora_tokenizer + TokenizerGroup.get_lora_tokenizer = patched_get_lora_tokenizer + + +# FlattenedTensor, code borrowed from sglang/srt/weight_sync/tensor_bucket.py +class FlattenedTensorMetadata(BaseModel): + """Metadata for a tensor in a flattened bucket""" + name: str + shape: Tuple[int, ...] + dtype: str + start_idx: int + end_idx: int + numel: int + + @field_validator('shape', mode='before') + @classmethod + def ensure_shape_tuple(cls, v: Any) -> Tuple[int, ...]: + # accept tuple/list, torch.Size, or other iterable of ints + if torch is not None and isinstance(v, torch.Size): + return tuple(int(x) for x in v) + if isinstance(v, (list, tuple)): + return tuple(int(x) for x in v) + if isinstance(v, Iterable): + return tuple(int(x) for x in v) + raise ValueError('shape must be an iterable of ints (e.g. tuple/list/torch.Size)') + + @field_validator('dtype', mode='before') + @classmethod + def ensure_dtype_str(cls, v: Any) -> str: + # accept torch.dtype or str + if torch is not None and isinstance(v, torch.dtype): + return str(v) + if isinstance(v, str): + return v + raise ValueError('dtype must be a torch.dtype or str') + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. 
+ Args: + named_tensors: List of (name, tensor) tuples (for creating new bucket) + flattened_tensor: Pre-flattened tensor (for reconstruction) + metadata: Pre-computed metadata (for reconstruction) + """ + if named_tensors is not None: + # Create bucket from named tensors + self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors) + self.flattened_tensor: torch.Tensor = None + + if not named_tensors: + raise ValueError('Cannot create empty tensor bucket') + + # First pass: compute total size and metadata + current_idx = 0 + total_numel = 0 + for i, (name, tensor) in enumerate(named_tensors): + numel = tensor.numel() + metadata_obj = FlattenedTensorMetadata( + name=name, + shape=tuple(tensor.shape), + dtype=str(tensor.dtype), + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + self.metadata[i] = metadata_obj + current_idx += numel + total_numel += numel + + # Pre-allocate the final flattened tensor to avoid intermediate copies + # Use the dtype and device of the first tensor + first_tensor = named_tensors[0][1] + self.flattened_tensor = torch.empty(total_numel, dtype=first_tensor.dtype, device=first_tensor.device) + + # Second pass: copy data directly into pre-allocated tensor + for meta, (name, tensor) in zip(self.metadata, named_tensors): + self.flattened_tensor[meta.start_idx:meta.end_idx].copy_(tensor.flatten()) + else: + # Initialize from pre-flattened data + if flattened_tensor is None or metadata is None: + raise ValueError('Must provide either named_tensors or both flattened_tensor and metadata') + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Get the flattened tensor containing all bucket tensors""" + return self.flattened_tensor + + def get_metadata(self) -> List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> Dict[str, torch.Tensor]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. 
+ """ + # preallocate the result list + reconstructed = {} + + for meta in self.metadata: + tensor = self.flattened_tensor[meta.start_idx:meta.end_idx].reshape(meta.shape) + dtype = getattr(torch, meta.dtype.split('.')[-1]) + # batch dtype conversion (if needed) + if tensor.dtype != dtype: + tensor = tensor.to(dtype) + + reconstructed[meta.name] = tensor + + return reconstructed + + def identity_data_collator(features): """Identity data collator that returns features as-is without any processing.""" return features @@ -563,3 +809,75 @@ def set_expandable_segments(enable: bool) -> None: if torch.cuda.is_available(): torch.cuda.memory._set_allocator_settings(f'expandable_segments:{enable}') os.environ['PYTORCH_CUDA_ALLOC_CONF'] = f'expandable_segments:{enable}' + + +def peft_config_to_dict(peft_config): + if not isinstance(peft_config, dict): + peft_config = asdict(peft_config) + # turn set to list to serializable + if 'target_modules' in peft_config and isinstance(peft_config['target_modules'], set): + peft_config['target_modules'] = list(peft_config['target_modules']) + + return peft_config + + +def _create_parameter_buckets(named_params, bucket_size_mb=100): + """Create parameter buckets grouped by dtype for efficient processing""" + buckets = [] + current_bucket = [] + current_size = 0 + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + + # Group parameters by dtype first, then by size + dtype_groups = {} + for name, param in named_params: + dtype = param.dtype + if dtype not in dtype_groups: + dtype_groups[dtype] = [] + dtype_groups[dtype].append((name, param)) + + # Create buckets within each dtype group + for dtype, params in dtype_groups.items(): + for name, param in params: + param_size = param.numel() * param.element_size() + + # If adding this param would exceed bucket size, start a new bucket + if current_size + param_size > bucket_size_bytes and current_bucket: + buckets.append(current_bucket) + current_bucket = [] + current_size = 0 + + current_bucket.append((name, param)) + current_size += param_size + + # Add remaining params in current bucket + if current_bucket: + buckets.append(current_bucket) + current_bucket = [] + current_size = 0 + + return buckets + + +def _process_bucket_with_flattened_tensor(trainer, bucket_params): + """Process a bucket of parameters using FlattenedTensorBucket for efficiency""" + if not bucket_params: + return + + # Create FlattenedTensorBucket for efficient processing + bucket = FlattenedTensorBucket(named_tensors=bucket_params) + metadatas = bucket.get_metadata() + flattened_tensor = bucket.get_flattened_tensor() + + # Use the new flattened parameter update method + # If not available, fall back to individual parameter updates + try: + trainer.vllm_client.update_flattened_params(metadatas, flattened_tensor) + except AttributeError: + # Fallback to individual parameter updates + reconstructed = bucket.reconstruct_tensors() + for name, param in reconstructed.items(): + trainer.vllm_client.update_named_param(name, param) + + # Clean up + del bucket, metadatas, flattened_tensor diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index fc3dc2e614..8440a220bb 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -4,13 +4,15 @@ import threading import time from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict from typing import List, Optional, Union from urllib.parse import urlparse +import json import requests import 
torch
from packaging import version

-from pydantic import ValidationError
+from pydantic import BaseModel, ValidationError
from requests import ConnectionError
from torch import nn
from transformers.utils import is_torch_cuda_available
@@ -19,6 +21,7 @@
from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig, RolloutOutput
from swift.plugin import Metric
from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available
+from .utils import peft_config_to_dict

if is_vllm_available():
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
@@ -242,6 +245,87 @@ def _update_single_server(i):
        if all_errors:
            raise RuntimeError(f'Multiple errors: {all_errors}')

+    def update_adapter_flattened_param(self, peft_config, metadatas, flattened_tensor):
+        """
+        Adds a LoRA adapter, rebuilt from a flattened tensor bucket, to the model on all servers.
+
+        Args:
+            peft_config: PEFT LoRA configuration of the adapter (dataclass or dict).
+            metadatas: List of FlattenedTensorMetadata describing the tensors in the flattened bucket.
+            flattened_tensor: The flattened tensor containing all adapter weights, broadcast via NCCL.
+        """
+        errors = [None] * self.num_servers
+        peft_config = peft_config_to_dict(peft_config)
+        metadatas = [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas]
+        lora_int_id = int(time.time_ns() % 0x7FFFFFFF)
+
+        def _update_single_server(i):
+            try:
+                data = {
+                    'lora_int_id': lora_int_id,
+                    'peft_config': {
+                        **peft_config
+                    },
+                    'metadatas': metadatas,
+                }
+
+                response = self.sessions[i].post(
+                    f'{self.base_urls[i]}/update_adapter_flattened_param/',
+                    json=data,
+                )
+                if response.status_code != 200:
+                    raise Exception(f'Server {i} update adapter failed: {response.text}')
+
+                self.pynccl_comms[i].broadcast(flattened_tensor, src=self.pynccl_comms[i].rank)
+                self.pynccl_comms[i].group.barrier()
+            except Exception as e:
+                errors[i] = e
+
+        with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
+            futures = [executor.submit(_update_single_server, i) for i in range(self.num_servers)]
+            for future in futures:
+                future.result()
+
+        all_errors = [e for e in errors if e is not None]
+        if all_errors:
+            raise RuntimeError(f'Multiple errors: {all_errors}')
+
+    def update_flattened_params(self, metadatas, flattened_tensor):
+        """
+        Updates model parameters using flattened tensor data. 
+ + Args: + metadatas: List of FlattenedTensorMetadata objects + flattened_tensor: The flattened tensor containing all parameters + """ + errors = [None] * self.num_servers + metadatas = [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas] + + def _update_single_server(i): + try: + data = { + 'metadatas': metadatas, + } + + response = self.sessions[i].post( + f'{self.base_urls[i]}/update_flattened_params/', + json=data, + ) + if response.status_code != 200: + raise Exception(f'Server {i} update flattened params failed: {response.text}') + + self.pynccl_comms[i].broadcast(flattened_tensor, src=self.pynccl_comms[i].rank) + self.pynccl_comms[i].group.barrier() + except Exception as e: + errors[i] = e + + with ThreadPoolExecutor(max_workers=self.num_servers) as executor: + futures = [executor.submit(_update_single_server, i) for i in range(self.num_servers)] + for future in futures: + future.result() + + all_errors = [e for e in errors if e is not None] + if all_errors: + raise RuntimeError(f'Multiple errors: {all_errors}') + def update_model_params(self, model: nn.Module): for name, param in model.named_parameters(): self.update_named_param(name, param.data) @@ -275,6 +359,7 @@ def get_engine_type(self): self.use_async_engine = result['engine_type'] == 'AsyncLLMEngine' self.enable_multi_turn = result.get('enable_multi_turn', False) self.use_gym_env = result.get('gym_env', False) + self.enable_lora = result.get('enable_lora', False) return result def close_communicator(self):
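For reference, the fast LoRA weight-sync path introduced in this PR is driven entirely by launch flags; the sketch below shows how the rollout-side and training-side flags pair up. It is abridged and illustrative (model, device, port, and rank values are placeholders, and the training command omits the dataset/reward/generation arguments shown in `examples/train/grpo/external/mllm_lora.sh`); the one hard constraint, per the parameter docs above, is that `--vllm_max_lora_rank` must be greater than or equal to the training `--lora_rank`.

```bash
# Rollout side: a vLLM server that can hot-load LoRA adapters pushed by the trainer.
CUDA_VISIBLE_DEVICES=7 \
swift rollout \
    --model Qwen/Qwen2.5-7B-Instruct \
    --vllm_enable_lora true \
    --vllm_max_lora_rank 16

# Training side: after the initial full sync (base_sync_done), only the LoRA adapter
# weights are flattened and broadcast to the server on each update.
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --use_vllm true \
    --vllm_mode server \
    --vllm_server_host 127.0.0.1 \
    --vllm_server_port 8000 \
    --train_type lora \
    --lora_rank 16 \
    --lora_alpha 32
    # ...remaining GRPO arguments (dataset, reward_funcs, generation settings) as in mllm_lora.sh
```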