diff --git "a/docs/source/BestPractices/GPTQ\351\207\217\345\214\226\346\250\241\345\236\213GRPO\350\256\255\347\273\203.md" "b/docs/source/BestPractices/GPTQ\351\207\217\345\214\226\346\250\241\345\236\213GRPO\350\256\255\347\273\203.md" new file mode 100644 index 0000000000..aa9e7fc2b4 --- /dev/null +++ "b/docs/source/BestPractices/GPTQ\351\207\217\345\214\226\346\250\241\345\236\213GRPO\350\256\255\347\273\203.md" @@ -0,0 +1,88 @@ +# 采用Colocate模式进行GPTQ量化模型的GRPO训练 + +## 1. 问题和可能的解决方法 + +已知:采用vLLM加速时目前代码会合Lora再更新vllm服务的模型的参数,但是GPTQ量化模型无法合lora。 + +实际:采用VLLM加速,量化模型在move model to llm时会出错。报错:AttributeError: 'GPTQLoraLinear' object has no attribute 'get_delta_weight',同https://github.com/modelscope/ms-swift/issues/3949。 + +现在的框架只能在不采用VLLM推理加速的情况下训练,速度非常慢。(不考虑此方案) + +针对这个问题有两种解决方法: + +- 方案1:修改ms-swift,在move_model_to_vllm中改为每步暂存Lora参数到本地,调用LLM engine时通过Adapter-request参数传递lora参数 + +- 方案2:反量化GPTQ-int4模型,在此基础上进行训练,保存lora,最后基模采用量化版本的。 + +## 2. 方案2 + +针对方案2,优先测试了ms-swift能否支持非量化的32B模型的Lora模式的GRPO。发现: +- server模式下的VLLM不支持。在更新VLLM服务的模型的参数时会出错,报通信超时错误,同https://github.com/modelscope/ms-swift/issues/4797。 +- colocate模式下可以。 + +目前还没写出无误的GPTQ反量化代码,所以方案2暂时进行到这里。 + +## 3. 方案1 + +针对方案1,按想法修改了ms-swift的代码,并且通过了测试,完成了实验。 + +### 3.1 示例脚本 + +```bash +MASTER_PORT=29502 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +NPROC_PER_NODE=4 \ +swift rlhf \ + --rlhf_type grpo \ + --model /xxx/deepseek-r1-distill-qwen-32b-gptq-int4 \ + --external_plugins examples/train/grpo/plugin/plugin.py \ + --reward_funcs external_xxx_accuracy external_xxx_format external_xxx_len \ + --reward_weights 1.0 1.0 1.0 \ + --vllm_mode colocate \ + --use_vllm true \ + --vllm_gpu_memory_utilization 0.4 \ + --vllm_tensor_parallel_size 4 \ + --torch_dtype bfloat16 \ + --dataset 'xxx/xxx.json' \ + --max_completion_length 5120 \ + --num_train_epochs 5 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --eval_steps 50 \ + --save_steps 50 \ + --save_total_limit 10 \ + --logging_steps 5 \ + --max_length 16000 \ + --train_type lora \ + --lora_rank 8 \ + --lora_alpha 16 \ + --target_modules all-linear \ + --resume_only_model \ + --resume_from_checkpoint /xxx/checkpoint-xxx \ + --output_dir /xxx/xxx \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --num_generations 16 \ + --temperature 0.7 \ + --top_p 1.0 \ + --top_k 80 \ + --log_completions true \ + --report_to tensorboard \ + --model_type deepseek_r1_distill \ + --async_generate false \ + --deepspeed zero3 \ + --sleep_level 1 \ + --max_step 1500 \ + --vllm_max_model_len 30000 \ + --local_adapter_path /xxx/tmp_path_for_lora \ + +``` +### 3.2 注意事项 + +- 需要注意,此时不能用move_model_batches这个参数,也就是不将lora参数分batch传给vllm,否则会报错[rank0]: IndexError: too many indices for tensor of dimension 1。 + +- 如果是继续训练,比如先前基于sft训练了lora,想在此lora上继续训练,采用GRPO方式。那么如果先前采用的deepspeed阶段是zero3, 那么此时需要采用同样的zero3。不能采用建议的zero3_offload 、offload_optimizer true 、offload_model true 策略,否则会报错[rank0]: KeyError: 'bias_correction' + +- 如果遇到爆显存问题,可调低vllm_gpu_memory_utilization,vllm_max_model_len等值。 diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index 0ec531cfce..50ef234ba5 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -107,6 +107,7 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo undesirable_weight: float = 1.0 # PPO/GRPO/GKD temperature: float = 0.9 + local_adapter_path: str = None # RM center_rewards_coefficient: Optional[float] = None # GKD diff --git 
diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py
index 0ec531cfce..50ef234ba5 100644
--- a/swift/llm/argument/rlhf_args.py
+++ b/swift/llm/argument/rlhf_args.py
@@ -107,6 +107,7 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo
     undesirable_weight: float = 1.0
     # PPO/GRPO/GKD
     temperature: float = 0.9
+    local_adapter_path: Optional[str] = None
     # RM
     center_rewards_coefficient: Optional[float] = None
     # GKD
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index c2caafcc9a..d8cf3fa35a 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -247,6 +247,7 @@ def get_vllm_engine_kwargs(self):
 
 @dataclass
 class GRPOArgumentsMixin(VllmArguments):
+    local_adapter_path: Optional[str] = None
     epsilon: float = 0.2
     epsilon_high: Optional[float] = None
     delta: Optional[float] = None
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index b5783e4301..f0f88dfca1 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -18,6 +18,9 @@
 from types import MethodType
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 
+import shutil
+from swift.llm.infer.infer_engine.utils import AdapterRequest
+
 import json
 import torch
 import torch.nn as nn
@@ -107,6 +110,8 @@ def __init__(self,
         from swift.trainers.rlhf_arguments import GRPOConfig
         args: GRPOConfig = kwargs['args']
         self.args = args
+        self.local_adapter_path = getattr(args, 'local_adapter_path', None)
+        self.enable_lora = bool(self.local_adapter_path)
         self.ref_adapter_name = getattr(args, 'ref_adapter_name', None)
         self.model_adapter_name = None
         # for async generate
@@ -529,6 +534,7 @@ def prepare_vllm(self, model):
             max_model_len=self.args.vllm_max_model_len,
             seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
             template=self.template,
+            enable_lora=self.enable_lora,
             distributed_executor_backend='external_launcher',
         )
         return engine
@@ -568,34 +574,47 @@ def _move_model_to_vllm(self, skip_async_check=False):
                 parameter for name, parameter in self.model.named_parameters()
                 if not parameter_group or name in parameter_group
             ]
-            with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group):
-                self.model.merge_adapter()
-                state_dict = self.model.state_dict()
-                state_dict = {
-                    k.removeprefix('base_model.model.').replace('.base_layer', ''): v
-                    for k, v in state_dict.items()
-                }
-                state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k}
-                # When module to save, remove its prefix and discard the original module
-                state_dict = {
-                    k.replace('modules_to_save.default.', ''): v
-                    for k, v in state_dict.items() if 'original_module' not in k
-                }
-                if parameter_group_no_lora:
-                    parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora]
-                    state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora}
-                assert len(state_dict) > 0 and all(
-                    [state.shape != torch.Size([0]) for state in state_dict.values()])
-
-                if self.vllm_mode == 'server' and self.accelerator.is_main_process:
-                    for name, param in state_dict.items():
-                        self.vllm_client.update_named_param(name, param)
-                elif self.vllm_mode == 'colocate':
-                    llm_model = self.engine.inner_model
-                    llm_model.load_weights(state_dict.items())
-                with patch_lora_unmerge(self.model):
-                    self.model.unmerge_adapter()
-                del state_dict
+            # Save the LoRA adapter to local_adapter_path instead of merging it into the base model.
+            if self.local_adapter_path:
+                with gather_if_zero3(parameters):
+                    if self.accelerator.is_main_process:
+                        if os.path.exists(self.local_adapter_path):
+                            # Delete the adapter saved at the previous step.
+                            shutil.rmtree(self.local_adapter_path)
+                            logger.info(f'step {self.state.global_step}: deleted previous LoRA files')
+
+                        os.makedirs(self.local_adapter_path)
+                        self.model.save_pretrained(self.local_adapter_path, peft_format=True)
+                        logger.info(f'step {self.state.global_step}: saved the latest LoRA to the local adapter path')
+            else:
+                with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group):
+                    self.model.merge_adapter()
+                    state_dict = self.model.state_dict()
+                    state_dict = {
+                        k.removeprefix('base_model.model.').replace('.base_layer', ''): v
+                        for k, v in state_dict.items()
+                    }
+                    state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k}
+                    # When module to save, remove its prefix and discard the original module
+                    state_dict = {
+                        k.replace('modules_to_save.default.', ''): v
+                        for k, v in state_dict.items() if 'original_module' not in k
+                    }
+                    if parameter_group_no_lora:
+                        parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora]
+                        state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora}
+                    assert len(state_dict) > 0 and all(
+                        [state.shape != torch.Size([0]) for state in state_dict.values()])
+
+                    if self.vllm_mode == 'server' and self.accelerator.is_main_process:
+                        for name, param in state_dict.items():
+                            self.vllm_client.update_named_param(name, param)
+                    elif self.vllm_mode == 'colocate':
+                        llm_model = self.engine.inner_model
+                        llm_model.load_weights(state_dict.items())
+                    with patch_lora_unmerge(self.model):
+                        self.model.unmerge_adapter()
+                    del state_dict
         else:
             for name, param in self.model.named_parameters():
                 with gather_if_zero3([param]):
@@ -1949,7 +1968,17 @@ def _engine_infer(
                 asdict(request_config),
                 use_tqdm=use_tqdm)
         else:
-            res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)
+            # Pass the locally saved LoRA to the vLLM engine via an AdapterRequest.
+            if self.local_adapter_path:
+                if not os.path.exists(self.local_adapter_path):
+                    raise FileNotFoundError(f'LoRA adapter path not found: {self.local_adapter_path}')
+                tmp_name = f'lora_{self.state.global_step}'
+                adapter_request = AdapterRequest(tmp_name, self.local_adapter_path)
+                if self.accelerator.is_main_process:
+                    logger.info(f'adapter_request info: {adapter_request}')
+                res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm, adapter_request=adapter_request)
+            else:
+                res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)
         if all(isinstance(r, RolloutOutput) for r in res):
             return res
         else:
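
For context on the `_engine_infer` change above: the `AdapterRequest` path rests on vLLM's native LoRA hot-loading, which is what makes "save the adapter every step, point the engine at the directory" workable without merging weights into the quantized base model. The snippet below is a minimal standalone sketch of that underlying vLLM mechanism, not part of this patch; the model path, adapter directory, and prompt are placeholders.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# The base model is loaded once with LoRA support switched on.
llm = LLM(model='/path/to/base-model', enable_lora=True)
params = SamplingParams(temperature=0.7, max_tokens=256)

# A LoRARequest names the adapter, assigns it a unique integer id, and points
# at a directory containing adapter_config.json plus the adapter weights,
# which is exactly what _move_model_to_vllm now writes to local_adapter_path.
lora = LoRARequest('step_lora', 1, '/path/to/tmp_path_for_lora')
outputs = llm.generate(['Hello, world'], params, lora_request=lora)
print(outputs[0].outputs[0].text)
```

In the patched trainer, ms-swift's `AdapterRequest(tmp_name, self.local_adapter_path)` plays the role of the `LoRARequest` here; the per-step adapter name (`lora_{global_step}`) presumably exists to avoid reusing a stale cached adapter.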