From afa55a2a6ff461e5a36754ae3ed1224c595e6df4 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 17 Jul 2025 14:35:20 +0800 Subject: [PATCH 01/21] update --- swift/llm/argument/export_args.py | 3 +++ swift/llm/export/__init__.py | 1 + swift/llm/export/bin_dataset.py | 4 ++++ swift/llm/export/export.py | 3 +++ tests/train/test_export_bin_dataset.py | 8 ++++++++ 5 files changed, 19 insertions(+) create mode 100644 swift/llm/export/bin_dataset.py create mode 100644 tests/train/test_export_bin_dataset.py diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py index 370ffbc777..d888d769e6 100644 --- a/swift/llm/argument/export_args.py +++ b/swift/llm/argument/export_args.py @@ -40,6 +40,9 @@ class ExportArguments(MergeArguments, BaseArguments): quant_batch_size: int = 1 group_size: int = 128 + # bin_dataset + to_bin_dataset: bool = False + # ollama to_ollama: bool = False diff --git a/swift/llm/export/__init__.py b/swift/llm/export/__init__.py index e3330166f1..574c91589a 100644 --- a/swift/llm/export/__init__.py +++ b/swift/llm/export/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .bin_dataset import export_to_bin_dataset from .export import SwiftExport, export_main from .merge_lora import merge_lora from .ollama import export_to_ollama diff --git a/swift/llm/export/bin_dataset.py b/swift/llm/export/bin_dataset.py new file mode 100644 index 0000000000..059d573812 --- /dev/null +++ b/swift/llm/export/bin_dataset.py @@ -0,0 +1,4 @@ + + +def export_to_bin_dataset(args): + pass diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py index d78658ba62..c72d93825e 100644 --- a/swift/llm/export/export.py +++ b/swift/llm/export/export.py @@ -4,6 +4,7 @@ from swift.llm import ExportArguments, SwiftPipeline from swift.tuners import swift_to_peft_format from swift.utils import get_logger +from .bin_dataset import export_to_bin_dataset from .merge_lora import merge_lora from .ollama import export_to_ollama from .quant import quantize_model @@ -29,6 +30,8 @@ def run(self): quantize_model(args) elif args.to_ollama: export_to_ollama(args) + elif args.to_bin_dataset: + export_to_bin_dataset(args) elif args.to_mcore: from swift.megatron import convert_hf2mcore convert_hf2mcore(args) diff --git a/tests/train/test_export_bin_dataset.py b/tests/train/test_export_bin_dataset.py new file mode 100644 index 0000000000..eabac00ad4 --- /dev/null +++ b/tests/train/test_export_bin_dataset.py @@ -0,0 +1,8 @@ + + +def test_export_bin_dataset(): + from swift.llm import export_main, ExportArguments + export_main(ExportArguments(model='Qwen/Qwen2.5-7B-Instruct', + dataset='AI-ModelScope/alpaca-gpt4-data-zh', + to_bin_dataset=True)) + print() From b3e883d32851f14c26acc5f517861594a57d21de Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 17 Jul 2025 14:55:34 +0800 Subject: [PATCH 02/21] update --- swift/llm/export/bin_dataset.py | 26 +++++++++++++++++++++++++- swift/llm/train/sft.py | 6 +++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/swift/llm/export/bin_dataset.py b/swift/llm/export/bin_dataset.py index 059d573812..487d60ac63 100644 --- a/swift/llm/export/bin_dataset.py +++ b/swift/llm/export/bin_dataset.py @@ -1,4 +1,28 @@ +from ..dataset import EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, load_dataset + +def encode_dataset(args, template, dataset): + if args.packing: + dataset = PackingDataset( + template, + dataset, + num_proc=args.dataset_num_proc, + strict=args.strict, + 
load_from_cache_file=args.load_from_cache_file) + else: + preprocessor = EncodePreprocessor(template=template) + dataset = preprocessor( + dataset, + num_proc=args.dataset_num_proc, + load_from_cache_file=args.load_from_cache_file, + strict=args.strict) + return dataset def export_to_bin_dataset(args): - pass + _, processor = args.get_model_processor(load_model=False) + template = args.get_template(processor) + from ..train import SwiftSft + train_dataset, val_dataset = SwiftSft._get_dataset(args) + + for dataset_type, dataset in [('train', train_dataset), ('val', val_dataset)]: + dataset = encode_dataset(args, template, dataset) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 4c692431bd..2dac29c1b4 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -58,9 +58,9 @@ def _prepare_template(self) -> None: template.model = self.model self.template = template - def _get_dataset(self): + @staticmethod + def _get_dataset(args): # The random shuffling of the training set occurs in the dataloader of the trainer. - args = self.args dataset_kwargs = args.get_dataset_kwargs() train_dataset, val_dataset = load_dataset( args.dataset, split_dataset_ratio=args.split_dataset_ratio, shuffle=args.dataset_shuffle, **dataset_kwargs) @@ -92,7 +92,7 @@ def _save_val_dataset(self, val_dataset): def run(self): args = self.args - train_dataset, val_dataset = self._get_dataset() + train_dataset, val_dataset = self._get_dataset(args) train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) if args.task_type == 'seq_cls': From ae035b5e258775ea72440f70131472e2782a05d5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 17 Jul 2025 23:32:36 +0800 Subject: [PATCH 03/21] update --- swift/llm/export/bin_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/llm/export/bin_dataset.py b/swift/llm/export/bin_dataset.py index 487d60ac63..df41bb3903 100644 --- a/swift/llm/export/bin_dataset.py +++ b/swift/llm/export/bin_dataset.py @@ -21,6 +21,7 @@ def encode_dataset(args, template, dataset): def export_to_bin_dataset(args): _, processor = args.get_model_processor(load_model=False) template = args.get_template(processor) + template.set_mode('train') from ..train import SwiftSft train_dataset, val_dataset = SwiftSft._get_dataset(args) From 36ec88efcf36650b2b9d90a292cc3d38592a9ffd Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 17 Jul 2025 23:32:53 +0800 Subject: [PATCH 04/21] update --- tests/train/test_export_bin_dataset.py | 30 ++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/tests/train/test_export_bin_dataset.py b/tests/train/test_export_bin_dataset.py index eabac00ad4..c28ce5c32c 100644 --- a/tests/train/test_export_bin_dataset.py +++ b/tests/train/test_export_bin_dataset.py @@ -1,8 +1,30 @@ - def test_export_bin_dataset(): from swift.llm import export_main, ExportArguments - export_main(ExportArguments(model='Qwen/Qwen2.5-7B-Instruct', - dataset='AI-ModelScope/alpaca-gpt4-data-zh', - to_bin_dataset=True)) + export_main( + ExportArguments( + model='Qwen/Qwen2.5-7B-Instruct', + dataset='liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT', + to_bin_dataset=True, + dataset_num_proc=32, + packing=True, + load_from_cache_file=False, + )) print() + +def test_sft(): + from swift.llm import sft_main, TrainArguments + sft_main(TrainArguments( + model='Qwen/Qwen2.5-7B-Instruct', + dataset='liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT#1000', + dataset_num_proc=2, + packing=True, + attn_impl='flash_attn', + # 
load_from_cache_file=False, + )) + +if __name__ == '__main__': + # test_export_bin_dataset() + test_sft() +# 186426.75 +# 13896.07 From bc55ba4b44671c18e155f2bfbc5af95cbd2616f5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 19 Jul 2025 20:19:20 +0800 Subject: [PATCH 05/21] update --- swift/llm/train/sft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index f1419fbf50..094824cab4 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -58,8 +58,8 @@ def _prepare_template(self) -> None: template.model = self.model self.template = template - @staticmethod - def _get_dataset(args): + def _get_dataset(self, args=None): + args = args or self.args # The random shuffling of the training set occurs in the dataloader of the trainer. dataset_kwargs = args.get_dataset_kwargs() train_dataset, val_dataset = load_dataset( @@ -92,7 +92,7 @@ def _save_val_dataset(self, val_dataset): def run(self): args = self.args - train_dataset, val_dataset = self._get_dataset(args) + train_dataset, val_dataset = self._get_dataset() train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) if args.task_type == 'seq_cls': From 344a83cd69df148e03f3529a28c754bd8e78127a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Jul 2025 22:19:16 +0800 Subject: [PATCH 06/21] update --- swift/llm/argument/export_args.py | 6 ++- swift/llm/export/__init__.py | 2 +- swift/llm/export/bin_dataset.py | 29 -------------- swift/llm/export/cached_dataset.py | 31 +++++++++++++++ swift/llm/export/export.py | 6 +-- swift/llm/train/sft.py | 63 ++++++++++++++++++------------ swift/megatron/train/sft.py | 3 +- 7 files changed, 78 insertions(+), 62 deletions(-) delete mode 100644 swift/llm/export/bin_dataset.py create mode 100644 swift/llm/export/cached_dataset.py diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py index d888d769e6..1b6539e57c 100644 --- a/swift/llm/argument/export_args.py +++ b/swift/llm/argument/export_args.py @@ -40,8 +40,8 @@ class ExportArguments(MergeArguments, BaseArguments): quant_batch_size: int = 1 group_size: int = 128 - # bin_dataset - to_bin_dataset: bool = False + # cached_dataset + to_cached_dataset: bool = False # ollama to_ollama: bool = False @@ -82,6 +82,8 @@ def _init_output_dir(self): suffix = 'mcore' elif self.to_hf: suffix = 'hf' + elif self.to_cached_dataset: + suffix = 'cached_dataset' else: return diff --git a/swift/llm/export/__init__.py b/swift/llm/export/__init__.py index 574c91589a..6012848527 100644 --- a/swift/llm/export/__init__.py +++ b/swift/llm/export/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .bin_dataset import export_to_bin_dataset +from .cached_dataset import export_cached_dataset from .export import SwiftExport, export_main from .merge_lora import merge_lora from .ollama import export_to_ollama diff --git a/swift/llm/export/bin_dataset.py b/swift/llm/export/bin_dataset.py deleted file mode 100644 index df41bb3903..0000000000 --- a/swift/llm/export/bin_dataset.py +++ /dev/null @@ -1,29 +0,0 @@ - -from ..dataset import EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, load_dataset - -def encode_dataset(args, template, dataset): - if args.packing: - dataset = PackingDataset( - template, - dataset, - num_proc=args.dataset_num_proc, - strict=args.strict, - load_from_cache_file=args.load_from_cache_file) - else: - preprocessor = EncodePreprocessor(template=template) - dataset = preprocessor( - dataset, - num_proc=args.dataset_num_proc, - load_from_cache_file=args.load_from_cache_file, - strict=args.strict) - return dataset - -def export_to_bin_dataset(args): - _, processor = args.get_model_processor(load_model=False) - template = args.get_template(processor) - template.set_mode('train') - from ..train import SwiftSft - train_dataset, val_dataset = SwiftSft._get_dataset(args) - - for dataset_type, dataset in [('train', train_dataset), ('val', val_dataset)]: - dataset = encode_dataset(args, template, dataset) diff --git a/swift/llm/export/cached_dataset.py b/swift/llm/export/cached_dataset.py new file mode 100644 index 0000000000..54c5e86377 --- /dev/null +++ b/swift/llm/export/cached_dataset.py @@ -0,0 +1,31 @@ + +from ..dataset import EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, load_dataset +from swift.llm import ExportArguments +from swift.llm.train import SwiftSft +from typing import Union, List + + +class ExportCachedDataset(SwiftSft): + args_class = ExportArguments + args: args_class + + def __init__(self, args: Union[List[str], ExportArguments, None] = None) -> None: + super(SwiftSft, self).__init__(args) + self.train_msg = {} # dummy + self.processor = None + self._prepare_template() + self._prepare_model_tokenizer(load_model=self.template.use_model) + self.template.init_processor(self.processor) + + def _save_val_dataset(self, val_dataset): + pass + + def run(self): + self.args.lazy_tokenize = False + train_dataset, val_dataset = self._get_dataset() + train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) + return + + +def export_cached_dataset(args: Union[List[str], ExportArguments, None] = None): + return ExportCachedDataset(args).main() diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py index c72d93825e..d82d755d92 100644 --- a/swift/llm/export/export.py +++ b/swift/llm/export/export.py @@ -4,7 +4,7 @@ from swift.llm import ExportArguments, SwiftPipeline from swift.tuners import swift_to_peft_format from swift.utils import get_logger -from .bin_dataset import export_to_bin_dataset +from .cached_dataset import export_cached_dataset from .merge_lora import merge_lora from .ollama import export_to_ollama from .quant import quantize_model @@ -30,8 +30,8 @@ def run(self): quantize_model(args) elif args.to_ollama: export_to_ollama(args) - elif args.to_bin_dataset: - export_to_bin_dataset(args) + elif args.to_cached_dataset: + export_cached_dataset(args) elif args.to_mcore: from swift.megatron import convert_hf2mcore convert_hf2mcore(args) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index c6e20c74f2..d361aaf1ae 100644 --- a/swift/llm/train/sft.py +++ 
b/swift/llm/train/sft.py @@ -24,6 +24,8 @@ class SwiftSft(SwiftPipeline, TunerMixin): def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None: super().__init__(args) self.train_msg = {} + args = self.args + self.is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' self._prepare_model_tokenizer() self._prepare_template() self._prepare_callbacks() @@ -35,13 +37,14 @@ def _prepare_generation_config(self): args.get_request_config(), self.tokenizer) logger.info(f'model.generation_config: {self.model.generation_config}') - def _prepare_model_tokenizer(self): + def _prepare_model_tokenizer(self, load_model=True): args = self.args if args.sequence_parallel_size > 1: from swift.trainers.sequence_parallel import sequence_parallel sequence_parallel.init_sequence_parallel(args.sequence_parallel_size) - self.model, self.processor = args.get_model_processor() - + self.model, self.processor = args.get_model_processor(load_model=load_model) + if self.model is None: + return if hasattr(self.model, 'hf_device_map'): logger.info(f'model.hf_device_map: {self.model.hf_device_map}') @@ -57,9 +60,9 @@ def _prepare_template(self) -> None: template.model = self.model self.template = template - def _get_dataset(self, args=None): - args = args or self.args + def _get_dataset(self): # The random shuffling of the training set occurs in the dataloader of the trainer. + args = self.args dataset_kwargs = args.get_dataset_kwargs() train_dataset, val_dataset = load_dataset( args.dataset, split_dataset_ratio=args.split_dataset_ratio, shuffle=args.dataset_shuffle, **dataset_kwargs) @@ -88,11 +91,19 @@ def _save_val_dataset(self, val_dataset): append_to_jsonl(val_dataset_path, val_dataset.to_list()) logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.') + def _prepare_dataset(self): + if args.cached_dataset: + train_dataset, val_dataset = self._get_cached_dataset() + else: + train_dataset, val_dataset = self._get_dataset() + train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) + if not self.is_grpo: + self._show_dataset(train_dataset, val_dataset) + return train_dataset, val_dataset + def run(self): args = self.args - - train_dataset, val_dataset = self._get_dataset() - train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) + train_dataset, val_dataset = self._prepare_dataset() if args.task_type == 'seq_cls': args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None) @@ -216,14 +227,29 @@ def _stat_dataset(self, dataset: Union[HfDataset, PackingDataset]): logger.info(f'Dataset Token Length: {stat_str}') return stat_str + def _show_dataset(self, train_dataset, val_dataset): + args = self.args + if is_master(): + inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset)) + self.template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {}) + elif hasattr(train_dataset, '__len__'): + # Avoid the random mismatch issue in LazyLLMDataset. 
+ inputs = train_dataset[0] + if val_dataset is not None and hasattr(val_dataset, '__len__') and len(val_dataset) == 0: + val_dataset = None + if not args.lazy_tokenize and not args.streaming: + self.train_msg['train_dataset'] = self._stat_dataset(train_dataset) + if val_dataset is not None and not predict_with_generate: + self.train_msg['val_dataset'] = self._stat_dataset(val_dataset) + def _encode_dataset(self, train_dataset, val_dataset): template = self.template args = self.args self._save_val_dataset(val_dataset) - is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' + predict_with_generate = getattr(args, 'predict_with_generate', False) datasets = [train_dataset, val_dataset] - if not is_grpo: + if not self.is_grpo: origin_template_model = template.model template.model = None # Avoid serializing the model. lazy_tokenize = args.lazy_tokenize and not args.packing @@ -238,7 +264,7 @@ def _encode_dataset(self, train_dataset, val_dataset): elif args.packing: packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset dataset = packing_dataset_cls( - self.template, + template, dataset, num_proc=args.dataset_num_proc, strict=args.strict, @@ -254,21 +280,8 @@ def _encode_dataset(self, train_dataset, val_dataset): dataset = LazyLLMDataset(dataset, template.encode) datasets[i] = dataset template.model = origin_template_model - train_dataset, val_dataset = datasets - if is_master(): - inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset)) - template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {}) - elif hasattr(train_dataset, '__len__'): - # Avoid the random mismatch issue in LazyLLMDataset. - inputs = train_dataset[0] - if val_dataset is not None and hasattr(val_dataset, '__len__') and len(val_dataset) == 0: - val_dataset = None - if not args.lazy_tokenize and not args.streaming: - self.train_msg['train_dataset'] = self._stat_dataset(train_dataset) - if val_dataset is not None and not predict_with_generate: - self.train_msg['val_dataset'] = self._stat_dataset(val_dataset) - return train_dataset, val_dataset + return datasets def sft_main(args: Union[List[str], TrainArguments, None] = None): diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 7b758ad2fa..986d002786 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -48,9 +48,8 @@ def _get_data_collator(self): def run(self): args = self.args + train_dataset, val_dataset = self._prepare_dataset() data_collator = self._get_data_collator() - train_dataset, val_dataset = self._get_dataset() - train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) if args.streaming: train_dataset = build_streaming_dataloader(args, train_dataset, data_collator) From 4b3c2ddb82ea389a29723461aa0ba2d6caf09756 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Jul 2025 22:19:58 +0800 Subject: [PATCH 07/21] update --- swift/llm/train/sft.py | 2 +- tests/train/test_export_bin_dataset.py | 30 ----------------------- tests/train/test_export_cached_dataset.py | 30 +++++++++++++++++++++++ 3 files changed, 31 insertions(+), 31 deletions(-) delete mode 100644 tests/train/test_export_bin_dataset.py create mode 100644 tests/train/test_export_cached_dataset.py diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index d361aaf1ae..381683ce22 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -246,7 +246,7 @@ def _encode_dataset(self, train_dataset, val_dataset): template 
= self.template args = self.args self._save_val_dataset(val_dataset) - + predict_with_generate = getattr(args, 'predict_with_generate', False) datasets = [train_dataset, val_dataset] if not self.is_grpo: diff --git a/tests/train/test_export_bin_dataset.py b/tests/train/test_export_bin_dataset.py deleted file mode 100644 index c28ce5c32c..0000000000 --- a/tests/train/test_export_bin_dataset.py +++ /dev/null @@ -1,30 +0,0 @@ - -def test_export_bin_dataset(): - from swift.llm import export_main, ExportArguments - export_main( - ExportArguments( - model='Qwen/Qwen2.5-7B-Instruct', - dataset='liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT', - to_bin_dataset=True, - dataset_num_proc=32, - packing=True, - load_from_cache_file=False, - )) - print() - -def test_sft(): - from swift.llm import sft_main, TrainArguments - sft_main(TrainArguments( - model='Qwen/Qwen2.5-7B-Instruct', - dataset='liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT#1000', - dataset_num_proc=2, - packing=True, - attn_impl='flash_attn', - # load_from_cache_file=False, - )) - -if __name__ == '__main__': - # test_export_bin_dataset() - test_sft() -# 186426.75 -# 13896.07 diff --git a/tests/train/test_export_cached_dataset.py b/tests/train/test_export_cached_dataset.py new file mode 100644 index 0000000000..79967b3675 --- /dev/null +++ b/tests/train/test_export_cached_dataset.py @@ -0,0 +1,30 @@ +def test_export_cached_dataset(): + from swift.llm import export_main, ExportArguments + export_main( + ExportArguments( + model='Qwen/Qwen2.5-7B-Instruct', + dataset='swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT', + to_cached_dataset=True, + dataset_num_proc=4, + # packing=True, + load_from_cache_file=False, + )) + print() + + +def test_sft(): + from swift.llm import sft_main, TrainArguments + sft_main( + TrainArguments( + model='Qwen/Qwen2.5-7B-Instruct', + dataset='liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT#1000', + dataset_num_proc=2, + packing=True, + attn_impl='flash_attn', + # load_from_cache_file=False, + )) + + +if __name__ == '__main__': + test_export_cached_dataset() + # test_sft() From df2a0cbf2e6a272a0195395553283f3e02fe9123 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 29 Jul 2025 15:18:11 +0800 Subject: [PATCH 08/21] update --- swift/llm/argument/base_args/base_args.py | 20 +++- swift/llm/argument/train_args.py | 13 --- swift/llm/dataset/utils.py | 7 +- swift/llm/export/cached_dataset.py | 17 ++-- swift/llm/infer/infer.py | 3 +- swift/llm/template/base.py | 7 +- swift/llm/train/sft.py | 110 +++++++++++++--------- swift/megatron/argument/train_args.py | 2 - 8 files changed, 105 insertions(+), 74 deletions(-) diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index 93434bb11a..7eea10da45 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -63,6 +63,8 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat seed (int): Random seed for reproducibility. Default is 42. model_kwargs (Optional[str]): Additional keyword arguments for the model. Default is None. load_data_args (bool): Flag to determine if dataset configuration should be loaded. Default is False. + packing (bool): Flag to enable packing of datasets. Default is False. + lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None. use_hf (bool): Flag to determine if Hugging Face should be used. Default is False. hub_token (Optional[str]): SDK token for authentication. Default is None. 
custom_register_path (List[str]): Path to custom .py file for dataset registration. Default is None. @@ -80,6 +82,8 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat load_data_args: bool = False # dataset packing: bool = False + lazy_tokenize: Optional[bool] = None + cached_dataset: List[str] = field(default_factory=list) custom_register_path: List[str] = field(default_factory=list) # .py # hub use_hf: bool = False @@ -97,6 +101,18 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat def _prepare_training_args(self, training_args: Dict[str, Any]) -> None: pass + def _init_lazy_tokenize(self): + if self.streaming and self.lazy_tokenize: + self.lazy_tokenize = False + logger.warning('Streaming and lazy_tokenize are incompatible. ' + f'Setting args.lazy_tokenize: {self.lazy_tokenize}.') + if self.lazy_tokenize is None: + if self.model_meta.is_multimodal and not self.streaming and not self.packing: + self.lazy_tokenize = True + else: + self.lazy_tokenize = False + logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}') + def _init_custom_register(self) -> None: """Register custom .py file to datasets""" if isinstance(self.custom_register_path, str): @@ -154,7 +170,9 @@ def __post_init__(self): QuantizeArguments.__post_init__(self) TemplateArguments.__post_init__(self) DataArguments.__post_init__(self) - + if isinstance(self.cached_dataset, str): + self.cached_dataset = [self.cached_dataset] + self._init_lazy_tokenize() self.hub = get_hub(self.use_hf) if self.hub.try_login(self.hub_token): logger.info('hub login successful!') diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index 872eb6d57b..33629cb13b 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -109,8 +109,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra Args: add_version (bool): Flag to add version information to output_dir. Default is True. loss_type (Optional[str]): Type of loss function to use. Default is None. - packing (bool): Flag to enable packing of datasets. Default is False. - lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None. max_new_tokens (int): Maximum number of new tokens to generate. Default is 64. temperature (float): Temperature for sampling. Default is 0. optimizer (Optional[str]): Optimizer type to use, define it in the plugin package. Default is None. @@ -118,7 +116,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra """ add_version: bool = True create_checkpoint_symlink: bool = False - lazy_tokenize: Optional[bool] = None # plugin loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'}) @@ -135,15 +132,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra # auto_tp deepspeed_autotp_size: Optional[int] = None - def _init_lazy_tokenize(self): - if self.streaming and self.lazy_tokenize: - self.lazy_tokenize = False - logger.warning('Streaming and lazy_tokenize are incompatible. 
' - f'Setting args.lazy_tokenize: {self.lazy_tokenize}.') - if self.lazy_tokenize is None: - self.lazy_tokenize = self.model_meta.is_multimodal and not self.streaming - logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}') - def __post_init__(self) -> None: if self.padding_free or self.packing: if self.packing: @@ -179,7 +167,6 @@ def __post_init__(self) -> None: self._init_deepspeed() self._init_device() - self._init_lazy_tokenize() if getattr(self, 'accelerator_config', None) is None: self.accelerator_config = {'dispatch_batches': False} diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index ae03a60659..9739eff032 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -146,11 +146,6 @@ def __init__( self.strict = strict self.load_from_cache_file = load_from_cache_file self.workers = [] - preprocessor = EncodePreprocessor(template=template) - self.dataset = preprocessor( - dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict) - if template.model_meta.is_multimodal: - self.dataset = LazyLLMDataset(self.dataset, encode_func=template.encode) self.packed_idx = self.create_packed_idx() if is_master() else None if dist.is_initialized() and is_dist(): obj_list = [self.packed_idx] @@ -286,7 +281,7 @@ def __init__(self, template: 'Template'): self.is_multimodal = template.model_meta.is_multimodal def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: - encoded = self.template.encode(row) + encoded = self.template.encode(row, return_length=True) if self.is_multimodal: row['length'] = encoded['length'] encoded = row diff --git a/swift/llm/export/cached_dataset.py b/swift/llm/export/cached_dataset.py index 54c5e86377..2e397a59bb 100644 --- a/swift/llm/export/cached_dataset.py +++ b/swift/llm/export/cached_dataset.py @@ -1,8 +1,12 @@ +import os +from typing import List, Union -from ..dataset import EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, load_dataset from swift.llm import ExportArguments from swift.llm.train import SwiftSft -from typing import Union, List +from swift.utils import get_logger +from ..dataset import EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, load_dataset + +logger = get_logger() class ExportCachedDataset(SwiftSft): @@ -17,14 +21,15 @@ def __init__(self, args: Union[List[str], ExportArguments, None] = None) -> None self._prepare_model_tokenizer(load_model=self.template.use_model) self.template.init_processor(self.processor) - def _save_val_dataset(self, val_dataset): - pass - def run(self): self.args.lazy_tokenize = False train_dataset, val_dataset = self._get_dataset() train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) - return + self._show_dataset(train_dataset, val_dataset) + train_dataset.save_to_disk(os.path.join(self.args.output_dir, 'train')) + if val_dataset is not None: + val_dataset.save_to_disk(os.path.join(self.args.output_dir, 'val')) + logger.info(f'Dataset saved to `{self.args.output_dir}`') def export_cached_dataset(args: Union[List[str], ExportArguments, None] = None): diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py index 5a859d8e14..e0eca6a5ca 100644 --- a/swift/llm/infer/infer.py +++ b/swift/llm/infer/infer.py @@ -250,7 +250,8 @@ def infer_dataset(self) -> List[Dict[str, Any]]: prog_bar.close() metrics = self.infer_kwargs.pop('metrics') if result_list: - print(f'[rank{args.rank}] {metrics[0].compute()}') + metric = metrics[0].compute() + 
print(f'[rank{args.rank}] {metric}' if args.rank >= 0 else str(metric)) if args.metric is not None: self._calc_metric() return result_list diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index ee504c72e5..1e1032511e 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -86,7 +86,7 @@ def __init__( from .template_meta import TemplateMeta from swift.plugin import agent_templates, loss_scale_map self._processor_inited = False - self._version = 'v1' # Avoid compatibility issues caused by load_from_cache_file caching. + self._version = 'v2' # Avoid compatibility issues caused by load_from_cache_file caching. self.max_length = max_length self.model = None @@ -488,7 +488,8 @@ def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: @torch.inference_mode() def encode(self, inputs: Union[TemplateInputs, Dict[str, Any], InferRequest], - return_template_inputs: bool = False) -> Dict[str, Any]: + return_template_inputs: bool = False, + return_length: bool = False) -> Dict[str, Any]: """The entrance method of Template! Returns: @@ -537,7 +538,7 @@ def encode(self, lengths.append(value) elif isinstance(value, (tuple, list)): lengths += value - if self.is_training: + if return_length: encoded['length'] = max(lengths) else: encoded.pop('length', None) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 381683ce22..cbf997796b 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -4,7 +4,9 @@ from typing import List, Union from datasets import Dataset as HfDataset +from datasets import load_from_disk +from swift.llm.dataset.loader import DatasetLoader from swift.plugin import extra_callbacks, get_loss_func, get_metric from swift.trainers import TrainerFactory from swift.utils import append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array @@ -25,7 +27,6 @@ def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None: super().__init__(args) self.train_msg = {} args = self.args - self.is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' self._prepare_model_tokenizer() self._prepare_template() self._prepare_callbacks() @@ -91,15 +92,55 @@ def _save_val_dataset(self, val_dataset): append_to_jsonl(val_dataset_path, val_dataset.to_list()) logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.') + def _get_cached_dataset(self): + args = self.args + train_datasets, val_datasets = [], [] + for cached_dataset in args.cached_dataset: + train_path = os.path.join(cached_dataset, 'train') + val_path = os.path.join(cached_dataset, 'val') + train_datasets.append(load_from_disk(train_path)) + if os.path.exists(val_path): + val_datasets.append(load_from_disk(val_path)) + return train_datasets, val_datasets + def _prepare_dataset(self): + args = self.args if args.cached_dataset: - train_dataset, val_dataset = self._get_cached_dataset() + train_datasets, val_datasets = self._get_cached_dataset() else: + train_datasets, val_datasets = [], [] + if args.dataset: train_dataset, val_dataset = self._get_dataset() train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) - if not self.is_grpo: - self._show_dataset(train_dataset, val_dataset) - return train_dataset, val_dataset + train_datasets.append(train_dataset) + val_datasets.append(val_dataset) + train_dataset = DatasetLoader._concat_datasets(train_datasets) + val_dataset = DatasetLoader._concat_datasets(val_datasets) + is_grpo = hasattr(args, 'rlhf_type') and 
args.rlhf_type == 'grpo' + datasets = [train_dataset, val_dataset] + if is_grpo: + return datasets + for i, dataset in enumerate(datasets): + if dataset is None: + continue + if i == 1 and predict_with_generate: + # val_dataset + continue + if lazy_tokenize: + dataset = LazyLLMDataset(dataset, template.encode, strict=args.strict, random_state=args.data_seed) + elif args.packing: + packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset + dataset = packing_dataset_cls( + template, + dataset, + num_proc=args.dataset_num_proc, + strict=args.strict, + load_from_cache_file=args.load_from_cache_file) + elif args.model_meta.is_multimodal: + dataset = LazyLLMDataset(dataset, template.encode) + datasets[i] = dataset + self._show_dataset(*datasets) + return datasets def run(self): args = self.args @@ -215,12 +256,7 @@ def _prepare_callbacks(self): def _stat_dataset(self, dataset: Union[HfDataset, PackingDataset]): if isinstance(dataset, HfDataset): - # TODO: Temporary fix; awaiting template refactor. - try: - length = dataset['length'] - except KeyError: - logger.warning_once("The HfDataset is missing the 'length' column, skipping statistics.") - return + length = dataset['length'] else: length = dataset.dataset['length'] _, stat_str = stat_array(length) @@ -249,37 +285,27 @@ def _encode_dataset(self, train_dataset, val_dataset): predict_with_generate = getattr(args, 'predict_with_generate', False) datasets = [train_dataset, val_dataset] - if not self.is_grpo: - origin_template_model = template.model - template.model = None # Avoid serializing the model. - lazy_tokenize = args.lazy_tokenize and not args.packing - for i, dataset in enumerate(datasets): - if dataset is None: - continue - if i == 1 and predict_with_generate: - # val_dataset - continue - if lazy_tokenize: - dataset = LazyLLMDataset(dataset, template.encode, strict=args.strict, random_state=args.data_seed) - elif args.packing: - packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset - dataset = packing_dataset_cls( - template, - dataset, - num_proc=args.dataset_num_proc, - strict=args.strict, - load_from_cache_file=args.load_from_cache_file) - else: - preprocessor = EncodePreprocessor(template=template) - dataset = preprocessor( - dataset, - num_proc=args.dataset_num_proc, - load_from_cache_file=args.load_from_cache_file, - strict=args.strict) - if args.model_meta.is_multimodal: - dataset = LazyLLMDataset(dataset, template.encode) - datasets[i] = dataset - template.model = origin_template_model + is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' + if is_grpo: + return datasets + + origin_template_model = template.model + template.model = None # Avoid serializing the model. 
+ for i, dataset in enumerate(datasets): + if dataset is None: + continue + if i == 1 and predict_with_generate: + # val_dataset + continue + if not lazy_tokenize and not args.streaming: + preprocessor = EncodePreprocessor(template=template) + dataset = preprocessor( + dataset, + num_proc=args.dataset_num_proc, + load_from_cache_file=args.load_from_cache_file, + strict=args.strict) + datasets[i] = dataset + template.model = origin_template_model return datasets diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 928d87086f..7fa33c90c4 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -15,8 +15,6 @@ @dataclass class MegatronTrainArguments(MegatronArguments, BaseArguments): add_version: bool = True - # dataset - lazy_tokenize: bool = False def init_model_args(self, tokenizer, config): self.megatron_model_meta = get_megatron_model_meta(self.model_type) From b196d2f0e71def9ee0a5fafb79d2cc8104075d3d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 29 Jul 2025 15:28:02 +0800 Subject: [PATCH 09/21] fix --- swift/llm/export/cached_dataset.py | 2 +- swift/llm/train/sft.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/swift/llm/export/cached_dataset.py b/swift/llm/export/cached_dataset.py index 2e397a59bb..a9ecff4d5d 100644 --- a/swift/llm/export/cached_dataset.py +++ b/swift/llm/export/cached_dataset.py @@ -1,10 +1,10 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import List, Union from swift.llm import ExportArguments from swift.llm.train import SwiftSft from swift.utils import get_logger -from ..dataset import EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, load_dataset logger = get_logger() diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index cbf997796b..7f7ca8a7f7 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -26,7 +26,6 @@ class SwiftSft(SwiftPipeline, TunerMixin): def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None: super().__init__(args) self.train_msg = {} - args = self.args self._prepare_model_tokenizer() self._prepare_template() self._prepare_callbacks() @@ -117,16 +116,18 @@ def _prepare_dataset(self): train_dataset = DatasetLoader._concat_datasets(train_datasets) val_dataset = DatasetLoader._concat_datasets(val_datasets) is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' + predict_with_generate = getattr(args, 'predict_with_generate', False) datasets = [train_dataset, val_dataset] if is_grpo: return datasets + template = self.template for i, dataset in enumerate(datasets): if dataset is None: continue if i == 1 and predict_with_generate: # val_dataset continue - if lazy_tokenize: + if args.lazy_tokenize: dataset = LazyLLMDataset(dataset, template.encode, strict=args.strict, random_state=args.data_seed) elif args.packing: packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset @@ -254,7 +255,8 @@ def _prepare_callbacks(self): callbacks += extra_callbacks self.callbacks = callbacks - def _stat_dataset(self, dataset: Union[HfDataset, PackingDataset]): + @staticmethod + def _stat_dataset(dataset: Union[HfDataset, PackingDataset]): if isinstance(dataset, HfDataset): length = dataset['length'] else: @@ -265,6 +267,7 @@ def _stat_dataset(self, dataset: Union[HfDataset, PackingDataset]): def _show_dataset(self, train_dataset, val_dataset): args = self.args + predict_with_generate = getattr(args, 
'predict_with_generate', False) if is_master(): inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset)) self.template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {}) @@ -283,9 +286,9 @@ def _encode_dataset(self, train_dataset, val_dataset): args = self.args self._save_val_dataset(val_dataset) + is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' predict_with_generate = getattr(args, 'predict_with_generate', False) datasets = [train_dataset, val_dataset] - is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' if is_grpo: return datasets @@ -297,7 +300,7 @@ def _encode_dataset(self, train_dataset, val_dataset): if i == 1 and predict_with_generate: # val_dataset continue - if not lazy_tokenize and not args.streaming: + if not args.lazy_tokenize and not args.streaming: preprocessor = EncodePreprocessor(template=template) dataset = preprocessor( dataset, From d48a02d5b775f52ba2e9850f8473621fcdc1e621 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 29 Jul 2025 16:05:41 +0800 Subject: [PATCH 10/21] update --- swift/llm/dataset/utils.py | 21 ++++++++++----------- swift/llm/export/cached_dataset.py | 3 +-- swift/llm/template/base.py | 16 +++++++++------- swift/llm/template/template/qwen.py | 4 ++-- swift/llm/train/sft.py | 9 ++++----- 5 files changed, 26 insertions(+), 27 deletions(-) diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index 9739eff032..3c5de3343e 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -146,18 +146,18 @@ def __init__( self.strict = strict self.load_from_cache_file = load_from_cache_file self.workers = [] - self.packed_idx = self.create_packed_idx() if is_master() else None + self.packed_idx, self.packed_length = self.create_packed_idx() if is_master() else None if dist.is_initialized() and is_dist(): - obj_list = [self.packed_idx] + obj_list = [(self.packed_idx, self.packed_length)] dist.broadcast_object_list(obj_list) - self.packed_idx = obj_list[0] + self.packed_idx, self.packed_length = obj_list[0] def create_packed_idx(self): lengths = self.dataset['length'] data = [(i, length) for i, length in enumerate(lengths)] i = 0 PACKING_BATCH_SIZE = 1000 - input_data, res = [], [] + input_data, packed_idx, packed_length = [], [], [] with tqdm(total=len(data), dynamic_ncols=True, desc='Packing: ') as prog_bar: while True: new_data = data[i:i + PACKING_BATCH_SIZE] @@ -168,14 +168,13 @@ def create_packed_idx(self): i += PACKING_BATCH_SIZE is_finished = i >= len(data) sequences, input_data = calculate_matched_group(self.template, input_data, is_finished=is_finished) - res += sequences - return res + packed_idx += [[x[0] for x in seq] for seq in sequences] + packed_length += [sum(x[1] for x in seq) for seq in sequences] + return packed_idx, packed_length def __getitem__(self, index): sequence = self.packed_idx[index] - row = [] - for i, length in sequence: - row.append((self.dataset[i], length)) + row = [self.dataset[i] for i in sequence] return self.template.packing_row(row) def __len__(self): @@ -216,7 +215,7 @@ def _processor(self): i, data = self._in_queue.get() encoded_data = {} try: - encoded_data = self.template.encode(data) + encoded_data = self.template.encode(data, return_length=True) except Exception as e: if self.strict and not isinstance(e, MaxLengthError): raise @@ -266,7 +265,7 @@ def __iter__(self): sequences, data = calculate_matched_group(self.template, data, is_finished=finished) res = [] for row in sequences: - 
packed = self.template.packing_row(row) + packed = self.template.packing_row([r[0] for r in row]) res.append(packed) yield from res if finished: diff --git a/swift/llm/export/cached_dataset.py b/swift/llm/export/cached_dataset.py index a9ecff4d5d..5155b83bbf 100644 --- a/swift/llm/export/cached_dataset.py +++ b/swift/llm/export/cached_dataset.py @@ -21,8 +21,7 @@ def __init__(self, args: Union[List[str], ExportArguments, None] = None) -> None self._prepare_model_tokenizer(load_model=self.template.use_model) self.template.init_processor(self.processor) - def run(self): - self.args.lazy_tokenize = False + def main(self): train_dataset, val_dataset = self._get_dataset() train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) self._show_dataset(train_dataset, val_dataset) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 1e1032511e..5ea56ca49f 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -548,22 +548,24 @@ def encode(self, encoded['_extra_kwargs'] = extra_kwargs return encoded - def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]: + def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]: packed = {} keys = set() + length = [] for r in row: - keys.update(r[0].keys()) + keys.update(r.keys()) + length.append(r['length']) for key in keys: if key in {'input_ids', 'labels', 'loss_scale'}: - packed[key] = sum((x[0][key] for x in row), start=[]) + packed[key] = sum((x[key] for x in row), start=[]) elif key == 'length': - packed[key] = sum((x[0][key] for x in row)) + packed[key] = sum((x[key] for x in row)) elif key == 'channel': - packed[key] = [x[0][key] for x in row] + packed[key] = [x[key] for x in row] if 'position_ids' not in packed: - packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[]) + packed['position_ids'] = sum((list(range(x)) for x in length), start=[]) - packed.update(self._data_collator_mm_data([r[0] for r in row])) + packed.update(self._data_collator_mm_data(row)) return packed def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index e3a3571016..4c4a761851 100644 --- a/swift/llm/template/template/qwen.py +++ b/swift/llm/template/template/qwen.py @@ -381,10 +381,10 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: res[f'{media_type}_grid_thw'] = grid_thw return res - def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]: + def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]: position_ids = [] for r in row: - r = r[0].copy() + r = r.copy() r['input_ids'] = torch.tensor(r['input_ids'])[None] position_ids.append(self._get_position_ids(r)) packed = super().packing_row(row) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 7f7ca8a7f7..cd0aef5eb0 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -93,6 +93,7 @@ def _save_val_dataset(self, val_dataset): def _get_cached_dataset(self): args = self.args + assert not args.streaming and not args.lazy_tokenize train_datasets, val_datasets = [], [] for cached_dataset in args.cached_dataset: train_path = os.path.join(cached_dataset, 'train') @@ -127,9 +128,9 @@ def _prepare_dataset(self): if i == 1 and predict_with_generate: # val_dataset continue - if args.lazy_tokenize: + if (args.model_meta.is_multimodal or args.lazy_tokenize) and not args.streaming: dataset = LazyLLMDataset(dataset, 
template.encode, strict=args.strict, random_state=args.data_seed) - elif args.packing: + if args.packing: packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset dataset = packing_dataset_cls( template, @@ -137,8 +138,6 @@ def _prepare_dataset(self): num_proc=args.dataset_num_proc, strict=args.strict, load_from_cache_file=args.load_from_cache_file) - elif args.model_meta.is_multimodal: - dataset = LazyLLMDataset(dataset, template.encode) datasets[i] = dataset self._show_dataset(*datasets) return datasets @@ -260,7 +259,7 @@ def _stat_dataset(dataset: Union[HfDataset, PackingDataset]): if isinstance(dataset, HfDataset): length = dataset['length'] else: - length = dataset.dataset['length'] + length = dataset.packed_length _, stat_str = stat_array(length) logger.info(f'Dataset Token Length: {stat_str}') return stat_str From 543e91c0df7da63124092eed56ede5ac6c1a9b3f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 29 Jul 2025 16:06:13 +0800 Subject: [PATCH 11/21] update --- tests/train/test_export_cached_dataset.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/train/test_export_cached_dataset.py b/tests/train/test_export_cached_dataset.py index 79967b3675..02e9006e16 100644 --- a/tests/train/test_export_cached_dataset.py +++ b/tests/train/test_export_cached_dataset.py @@ -6,8 +6,6 @@ def test_export_cached_dataset(): dataset='swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT', to_cached_dataset=True, dataset_num_proc=4, - # packing=True, - load_from_cache_file=False, )) print() @@ -21,10 +19,9 @@ def test_sft(): dataset_num_proc=2, packing=True, attn_impl='flash_attn', - # load_from_cache_file=False, )) if __name__ == '__main__': - test_export_cached_dataset() - # test_sft() + # test_export_cached_dataset() + test_sft() From cd0b4605c4478b85f832a47eb2199b07fc5c3d84 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 00:09:13 +0800 Subject: [PATCH 12/21] update --- ...6\213\345\222\214\346\225\260\346\215\256\351\233\206.md" | 1 + docs/source_en/Instruction/Supported-models-and-datasets.md | 1 + swift/llm/model/model/mistral.py | 5 +++-- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index 3633821800..0e2ef01d0b 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -487,6 +487,7 @@ |[AI-ModelScope/Mistral-7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.1)|mistral|llama|transformers>=4.34|✘|-|[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)| |[AI-ModelScope/Mistral-7B-v0.2-hf](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.2-hf)|mistral|llama|transformers>=4.34|✘|-|[alpindale/Mistral-7B-v0.2-hf](https://huggingface.co/alpindale/Mistral-7B-v0.2-hf)| |[swift/Codestral-22B-v0.1](https://modelscope.cn/models/swift/Codestral-22B-v0.1)|mistral|llama|transformers>=4.34|✘|-|[mistralai/Codestral-22B-v0.1](https://huggingface.co/mistralai/Codestral-22B-v0.1)| 
+|[mistralai/Devstral-Small-2505](https://modelscope.cn/models/mistralai/Devstral-Small-2505)|devstral|devstral|transformers>=4.43, mistral-common>=1.5.5|✘|-|[mistralai/Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505)| |[modelscope/zephyr-7b-beta](https://modelscope.cn/models/modelscope/zephyr-7b-beta)|zephyr|zephyr|transformers>=4.34|✘|-|[HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)| |[AI-ModelScope/Mixtral-8x7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-Instruct-v0.1)|mixtral|llama|transformers>=4.36|✘|-|[mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)| |[AI-ModelScope/Mixtral-8x7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-v0.1)|mixtral|llama|transformers>=4.36|✘|-|[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 7df8f01d99..027ae308c9 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -487,6 +487,7 @@ The table below introduces the models integrated with ms-swift: |[AI-ModelScope/Mistral-7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.1)|mistral|llama|transformers>=4.34|✘|-|[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)| |[AI-ModelScope/Mistral-7B-v0.2-hf](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.2-hf)|mistral|llama|transformers>=4.34|✘|-|[alpindale/Mistral-7B-v0.2-hf](https://huggingface.co/alpindale/Mistral-7B-v0.2-hf)| |[swift/Codestral-22B-v0.1](https://modelscope.cn/models/swift/Codestral-22B-v0.1)|mistral|llama|transformers>=4.34|✘|-|[mistralai/Codestral-22B-v0.1](https://huggingface.co/mistralai/Codestral-22B-v0.1)| +|[mistralai/Devstral-Small-2505](https://modelscope.cn/models/mistralai/Devstral-Small-2505)|devstral|devstral|transformers>=4.43, mistral-common>=1.5.5|✘|-|[mistralai/Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505)| |[modelscope/zephyr-7b-beta](https://modelscope.cn/models/modelscope/zephyr-7b-beta)|zephyr|zephyr|transformers>=4.34|✘|-|[HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)| |[AI-ModelScope/Mixtral-8x7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-Instruct-v0.1)|mixtral|llama|transformers>=4.36|✘|-|[mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)| |[AI-ModelScope/Mixtral-8x7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-v0.1)|mixtral|llama|transformers>=4.36|✘|-|[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)| diff --git a/swift/llm/model/model/mistral.py b/swift/llm/model/model/mistral.py index 54b7884fe0..16c59aef5e 100644 --- a/swift/llm/model/model/mistral.py +++ b/swift/llm/model/model/mistral.py @@ -9,7 +9,7 @@ from ..model_arch import ModelArch from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, register_model) -from ..utils import ModelInfo +from ..utils import ModelInfo, safe_snapshot_download register_model( ModelMeta( @@ -148,7 +148,8 @@ def get_model_tokenizer_devstral_2505(model_dir: str, load_model: bool = True, **kwargs): # src: sglang did the same (https://github.com/sgl-project/sglang/pull/6547) - 
tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-Small-3.1-24B-Instruct-2503') + tokenizer_dir = safe_snapshot_download('mistralai/Mistral-Small-3.1-24B-Instruct-2503', download_model=False) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) kwargs['tokenizer'] = tokenizer model, processor = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) From bc97ddf97a25bf1d6a4e55d53801a9c79cb98f77 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 00:10:08 +0800 Subject: [PATCH 13/21] update --- swift/llm/model/constant.py | 1 + swift/llm/model/model/mistral.py | 4 ++-- swift/llm/template/constant.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py index ba3e9daf11..1124fdb669 100644 --- a/swift/llm/model/constant.py +++ b/swift/llm/model/constant.py @@ -81,6 +81,7 @@ class LLMModelType: telechat2 = 'telechat2' mistral = 'mistral' + devstral = 'devstral' zephyr = 'zephyr' mixtral = 'mixtral' mistral_nemo = 'mistral_nemo' diff --git a/swift/llm/model/model/mistral.py b/swift/llm/model/model/mistral.py index 16c59aef5e..6d23b23b41 100644 --- a/swift/llm/model/model/mistral.py +++ b/swift/llm/model/model/mistral.py @@ -158,14 +158,14 @@ def get_model_tokenizer_devstral_2505(model_dir: str, register_model( ModelMeta( - model_type='devstral', + model_type=LLMModelType.devstral, model_groups=[ ModelGroup([ Model('mistralai/Devstral-Small-2505', 'mistralai/Devstral-Small-2505'), ], requires=['transformers>=4.43', 'mistral-common>=1.5.5']) ], - template='devstral', + template=TemplateType.devstral, get_function=get_model_tokenizer_devstral_2505, architectures=['MistralForCausalLM'], model_arch=ModelArch.llama)) diff --git a/swift/llm/template/constant.py b/swift/llm/template/constant.py index 12a8eca329..db7a423e04 100644 --- a/swift/llm/template/constant.py +++ b/swift/llm/template/constant.py @@ -71,6 +71,7 @@ class LLMTemplateType: mistral_nemo = 'mistral_nemo' mistral_2501 = 'mistral_2501' + devstral = 'devstral' zephyr = 'zephyr' wizardlm2 = 'wizardlm2' wizardlm2_moe = 'wizardlm2_moe' From 20b00dbefb00d8042bd4d59e87c0aa3e71f59ed5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 14:30:33 +0800 Subject: [PATCH 14/21] update --- swift/llm/dataset/utils.py | 4 ++-- swift/llm/template/base.py | 49 ++++++++++++++++++++------------------ 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index 3c5de3343e..c71daad6e9 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -94,7 +94,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: self._idx = (self._idx + 1) % len(self.dataset) data = self.dataset[i] try: - return self.encode_func(data) + return self.encode_func(data, return_length=True) except Exception: if n_try == self.n_try_fetch - 1 or self.strict: if self.strict: @@ -146,7 +146,7 @@ def __init__( self.strict = strict self.load_from_cache_file = load_from_cache_file self.workers = [] - self.packed_idx, self.packed_length = self.create_packed_idx() if is_master() else None + self.packed_idx, self.packed_length = self.create_packed_idx() if is_master() else (None, None) if dist.is_initialized() and is_dist(): obj_list = [(self.packed_idx, self.packed_length)] dist.broadcast_object_list(obj_list) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 5ea56ca49f..576609c8df 100644 --- a/swift/llm/template/base.py +++ 
b/swift/llm/template/base.py @@ -1153,6 +1153,26 @@ def _swift_encode(self, inputs: StdTemplateInputs): answer_len = 0 return res_context_list, loss_scale_list, answer_len + def _truncate(self, input_ids: List[int], labels: Optional[List[int]], loss_mask: Optional[List[float]], + truncation_strategy: Literal['left', 'right']): + placeholder_tokens = torch.tensor(self.placeholder_tokens) + input_ids_tensor = torch.tensor(input_ids) + protected = (input_ids_tensor[:, None] == placeholder_tokens).any(dim=-1) + n_protected = protected.sum().item() + if n_protected < self.max_length: + non_protected = (~protected).nonzero(as_tuple=True)[0] + if truncation_strategy == 'left': + idx = non_protected[-(self.max_length - n_protected):] + else: + idx = non_protected[:self.max_length - n_protected] + protected[idx] = True + input_ids = input_ids_tensor[protected].tolist() + if labels is not None: + labels = torch.tensor(labels)[protected].tolist() + if loss_mask is not None: + loss_mask = torch.tensor(loss_mask)[protected].tolist() + return input_ids, labels, loss_mask + def _encode_truncated(self, inputs: StdTemplateInputs): if inputs.is_multimodal: self._add_default_tags(inputs) @@ -1177,30 +1197,13 @@ def _encode_truncated(self, inputs: StdTemplateInputs): length = max(lengths) encoded['length'] = length - if self.max_length is not None: - if self.truncation_strategy == 'right': - input_ids = input_ids[:self.max_length] - if labels is not None: - labels = labels[:self.max_length] - if loss_scale is not None: - loss_scale = loss_scale[:self.max_length] - elif self.truncation_strategy == 'left': - if len(input_ids) > self.max_length: - logger.warning_once( - 'Input data was left-truncated because its length exceeds `max_length` (input length: ' - f'{len(input_ids)}, max_length: {self.max_length}). ' - 'This may cause loss of important tokens (e.g., image tokens) and lead to errors. 
' - 'To avoid this, consider increasing `max_length` or pre-filtering long sequences.', - hash_id='max_length_check') - input_ids = input_ids[-self.max_length:] - if labels is not None: - labels = labels[-self.max_length:] - if loss_scale is not None: - loss_scale = loss_scale[-self.max_length:] + if self.max_length is not None and length > self.max_length: + if self.truncation_strategy in {'right', 'left'}: + input_ids, labels, loss_scale = self._truncate( + input_ids, labels, loss_scale, truncation_strategy=self.truncation_strategy) elif self.truncation_strategy == 'raise': - if length > self.max_length: - raise MaxLengthError(f'Current length of row({length}) is larger' - f' than the max_length({self.max_length}).') + raise MaxLengthError(f'Current length of row({length}) is larger' + f' than the max_length({self.max_length}).') encoded['input_ids'] = input_ids encoded['labels'] = labels encoded['loss_scale'] = loss_scale From aafb2e105efac85e85086d28e254fa1f3dc49d90 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 15:12:42 +0800 Subject: [PATCH 15/21] update --- .../infer/infer_engine/grpo_vllm_engine.py | 2 +- .../llm/infer/infer_engine/lmdeploy_engine.py | 2 +- swift/llm/infer/infer_engine/pt_engine.py | 2 +- swift/llm/infer/infer_engine/sglang_engine.py | 2 +- swift/llm/infer/infer_engine/vllm_engine.py | 2 +- swift/llm/template/base.py | 38 ++++++++----------- 6 files changed, 21 insertions(+), 27 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index f0e6c14ae8..44689b684f 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -367,7 +367,7 @@ def _create_chat_completion_response(self, result, template: Template, request_c else: choice_cls = ChatCompletionResponseChoice - token_ids = output.token_ids if request_config.return_details else None + token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None choice = choice_cls( index=output.index, message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), diff --git a/swift/llm/infer/infer_engine/lmdeploy_engine.py b/swift/llm/infer/infer_engine/lmdeploy_engine.py index 8c81071352..d6145d1d1c 100644 --- a/swift/llm/infer/infer_engine/lmdeploy_engine.py +++ b/swift/llm/infer/infer_engine/lmdeploy_engine.py @@ -250,7 +250,7 @@ async def _infer_full_async( toolcall = self._get_toolcall(response, template) finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token, output.status.name == 'FINISH') - token_ids = output.token_ids if request_config.return_details else None + token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None choices = [ ChatCompletionResponseChoice( index=0, diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index 9a349b7ddf..d0ac31209c 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -396,7 +396,7 @@ def _infer_full(self, template: Template, inputs: Dict[str, Any], *, generation_ response = template.decode(generate_ids, template_inputs=template_inputs[i]) finish_reason = self._get_finish_reason(generation_config.max_new_tokens, len(generate_ids), True) toolcall = self._get_toolcall(response, template) - token_ids = generate_ids if request_config.return_details else None + token_ids = template.skip_stop_tokens(generate_ids) if 
request_config.return_details else None choices.append( ChatCompletionResponseChoice( index=j, diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index a19d5fb7ce..7d883c1c49 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -135,7 +135,7 @@ def _create_chat_completion_response(self, output, template, return_details: boo if template.template_meta.response_prefix: response = template.template_meta.response_prefix + response toolcall = self._get_toolcall(response, template) - token_ids = output['output_ids'] if return_details else None + token_ids = template.skip_stop_tokens(output['output_ids']) if return_details else None choice = ChatCompletionResponseChoice( index=0, message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index b220b45929..74cb5711b2 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -415,7 +415,7 @@ def _create_chat_completion_response( response = template.decode(output.token_ids) logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(response, template) - token_ids = output.token_ids if request_config.return_details else None + token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None choice = ChatCompletionResponseChoice( index=output.index, message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 576609c8df..e71f23001d 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -571,17 +571,6 @@ def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]: def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs - @staticmethod - def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]: - len_tokens = len(stop_tokens) - if is_finished and generate_ids[-len_tokens:] == stop_tokens: - return generate_ids[:-len_tokens] - if not is_finished: - for i in range(len_tokens, 0, -1): - if generate_ids[-i:] == stop_tokens[:i]: - return generate_ids[:-i] - return generate_ids - @staticmethod def _get_seq_cls_logprobs(pred: int, logprobs: torch.Tensor, top_logprobs: int): idxs = logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist() @@ -644,7 +633,10 @@ def decode(self, first_token=True, **kwargs) -> Any: tokenizer_kwargs = tokenizer_kwargs or {} - response = self._skip_stop_decode(generate_ids, is_finished, **tokenizer_kwargs) + if 'spaces_between_special_tokens' not in tokenizer_kwargs: + tokenizer_kwargs['spaces_between_special_tokens'] = False + generate_ids = self.skip_stop_tokens(generate_ids, is_finished) + response = self.tokenizer.decode(generate_ids) if first_token and self.template_meta.response_prefix: response = self.template_meta.response_prefix + response return response @@ -674,7 +666,7 @@ def generate(self, model, *args, **kwargs): kwargs['use_model_defaults'] = False return model.generate(*args, **kwargs) - def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any: + def skip_stop_tokens(self, generate_ids: List[int], is_finished: bool = True) -> List[int]: # Do not print template_meta.suffix[-1] and 
eos_token. # However, other stop_words will be printed. tokenizer = self.tokenizer @@ -686,10 +678,16 @@ def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode if isinstance(template_suffix, str): # [-1:]: fix OpenGVLab/Mini-InternVL-Chat-4B-V1-5 template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)[-1:] - generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished) - if 'spaces_between_special_tokens' not in decode_kwargs: - decode_kwargs['spaces_between_special_tokens'] = False - return tokenizer.decode(generate_ids, **decode_kwargs) + + len_tokens = len(template_suffix) + if is_finished and generate_ids[-len_tokens:] == template_suffix: + generate_ids = generate_ids[:-len_tokens] + elif not is_finished: + for i in range(len_tokens, 0, -1): + if generate_ids[-i:] == template_suffix[:i]: + generate_ids = generate_ids[:-i] + break + return generate_ids def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]: generation_config = generate_kwargs['generation_config'] @@ -1567,10 +1565,6 @@ def _seq_cls_data_collator(self, res['labels'] = labels return res - def _data_flatten(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - new_batch = [(row, len(row['input_ids'])) for row in batch] - return [self.packing_row(new_batch)] - def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: """ Args: @@ -1582,7 +1576,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in padding_side = self.padding_side if self.is_training else 'left' padding_right = padding_side == 'right' if self.padding_free: - batch[:] = self._data_flatten(batch) + batch[:] = [self.packing_row(batch)] if self._packing: assert 'position_ids' in batch[0], f'batch[0]: {batch[0]}' res = {} From dec5060a1874152cd72e7746f4e48187d86f133d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 15:39:43 +0800 Subject: [PATCH 16/21] update --- swift/llm/infer/infer_engine/pt_engine.py | 27 +++--- swift/llm/infer/infer_engine/sglang_engine.py | 5 -- swift/llm/infer/infer_engine/vllm_engine.py | 5 -- swift/llm/template/base.py | 86 +++++++++---------- 4 files changed, 54 insertions(+), 69 deletions(-) diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index d0ac31209c..3e234d580e 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -323,21 +323,27 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req elif 'last_hidden_state' in output: # embeddings logits = output['last_hidden_state'] - if template.mode == 'seq_cls': + if template.task_type == 'seq_cls': preds, logprobs = template.decode_seq_cls(logits, top_logprobs) - elif template.mode == 'prm': + elif template.task_type == 'prm': preds = template.decode_prm(inputs['input_ids'], logits) logprobs = [None] * len(preds) - elif template.mode == 'embedding': + elif template.task_type == 'embedding': preds = logits logprobs = [None] * len(preds) else: - raise ValueError(f'Unsupported mode: {template.mode}') + raise ValueError(f'Unsupported task_type: {template.task_type}') res = [] for i, pred in enumerate(preds): usage_info = self._get_usage_info(num_prompt_tokens, 1) - if template.mode != 'embedding': + if template.task_type == 'embedding': + res.append( + EmbeddingResponse( + model=self.model_name, + usage=usage_info, + 
data=[EmbeddingResponseData(embedding=pred.to(torch.float32).cpu().numpy().tolist())])) + else: choices = [ ChatCompletionResponseChoice( index=0, @@ -346,13 +352,6 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req logprobs=logprobs[i]) ] res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info)) - else: - res.append( - EmbeddingResponse( - model=self.model_name, - usage=usage_info, - data=[EmbeddingResponseData(embedding=pred.to(torch.float32).cpu().numpy().tolist())])) - return res def _infer_full(self, template: Template, inputs: Dict[str, Any], *, generation_config: GenerationConfig, @@ -510,8 +509,8 @@ def _gen_wrapper(): return _gen_wrapper() else: if len(kwargs) > 0: - infer_func = self._infer_forward if template.mode in ('seq_cls', 'prm', - 'embedding') else self._infer_full + infer_func = self._infer_forward if template.task_type in ('seq_cls', 'prm', + 'embedding') else self._infer_full res = infer_func(**kwargs) else: res = [] diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index 7d883c1c49..4711d15351 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -173,11 +173,6 @@ async def infer_async(self, template = self.default_template template.set_mode('pt') - if self.task_type == 'embedding': - # TODO Refactor me - template.infer_backend = 'sglang' - template.task_type = self.task_type - template.set_mode('embedding') loop = asyncio.get_running_loop() with torch.inference_mode(): inputs = await loop.run_in_executor(None, template.encode, infer_request) diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 74cb5711b2..12dcba1e00 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -550,11 +550,6 @@ async def infer_async( template = self.default_template template.set_mode('vllm') - if self.task_type == 'embed': - # TODO Refactor me - template.infer_backend = 'vllm' - template.task_type = 'embedding' - template.set_mode('embedding') loop = asyncio.get_running_loop() with torch.inference_mode(): inputs = await loop.run_in_executor(None, template.encode, infer_request) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index e71f23001d..302d92080b 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -119,8 +119,9 @@ def __init__( if self.is_encoder_decoder: self.skip_prompt = False self.mode: Literal['pt', 'vllm', 'lmdeploy', # infer - 'train', 'rlhf', 'kto', 'gkd', # train - 'seq_cls', 'embedding', 'prm'] = 'pt' + 'train', 'rlhf', 'kto', 'gkd'] = 'pt' # train + self.task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'prm', 'reranker', + 'generative_reranker'] = 'causal_lm' self._packing = self.padding_free self.use_megatron = False self._handles = [] @@ -136,8 +137,7 @@ def init_processor(self, processor: Processor) -> None: self.processor = processor self.model_info = processor.model_info self.config = self.model_info.config - if self.model_info.task_type != 'causal_lm': - self.mode = self.model_info.task_type + self.task_type = self.model_info.task_type self.model_meta = processor.model_meta if self.max_length is None: @@ -428,15 +428,7 @@ def split_multi_medias(_inputs): anchor.messages[-1]['content'] = '' anchor.rejected_response = [] split_multi_medias(anchor) - infer_backend = getattr(self, 'infer_backend', 'pt') - mode = self.mode - # 
infer_backend comes from vllm-engine, sglang-engine, etc. - # TODO Refactor me - if infer_backend in ('vllm', 'sglang'): - self.mode = infer_backend _encoded = self._encode_truncated(anchor) - if infer_backend in ('vllm', 'sglang'): - self.mode = mode _encoded.pop('labels', None) return _encoded @@ -511,20 +503,24 @@ def encode(self, assert isinstance(inputs, StdTemplateInputs) self._preprocess_inputs(inputs) - if self.mode in {'pt', 'train', 'prm', 'vllm', 'lmdeploy'}: - encoded = self._encode_truncated(inputs) - elif self.mode == 'seq_cls': + if self.task_type == 'causal_lm': + if self.mode in {'pt', 'train', 'vllm', 'lmdeploy'}: + encoded = self._encode_truncated(inputs) + elif self.mode == 'rlhf': + encoded = self._rlhf_encode(inputs) + elif self.mode == 'kto': + encoded = self._kto_encode(inputs) + elif self.mode == 'gkd': + encoded = self._gkd_encode(inputs) + elif self.task_type == 'seq_cls': encoded = self._seq_cls_encode(inputs) - elif self.mode == 'rlhf': - encoded = self._rlhf_encode(inputs) - elif self.mode == 'kto': - encoded = self._kto_encode(inputs) - elif self.mode == 'gkd': - encoded = self._gkd_encode(inputs) - elif self.mode == 'embedding': + elif self.task_type == 'prm': + encoded = self._encode_truncated(inputs) + elif self.task_type == 'embedding': encoded = self._embedding_encode(inputs) - elif self.mode in ['reranker', 'generative_reranker']: + elif self.task_type in {'reranker', 'generative_reranker'}: encoded = self._reranker_encode(inputs) + if inputs.channel is not None: encoded['channel'] = inputs.channel @@ -922,7 +918,7 @@ def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float idx = inputs.bbox_idx c_list = self.replace_bbox(bbox[idx], idx, inputs) inputs.bbox_idx += 1 - elif context == '' and self.mode == 'prm': + elif context == '' and self.task_type == 'prm': c_list = self.replace_cot_process(inputs) else: c_list = [context] @@ -1124,7 +1120,7 @@ def _swift_encode(self, inputs: StdTemplateInputs): # self.is_training needed because we may want to continue generation from # the current response # TODO Refactor me - if self.is_training and not sep_token or (getattr(self, 'task_type', None) == 'embedding'): + if self.is_training and not sep_token or self.task_type == 'embedding': extra_context_list = template_meta.suffix extra_context_type = ContextType.SUFFIX elif template_meta.response_prefix: @@ -1210,7 +1206,7 @@ def _encode_truncated(self, inputs: StdTemplateInputs): def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: template_backend = self.template_backend if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training - and self.mode != 'seq_cls'): + and self.task_type != 'seq_cls'): template_backend = 'jinja' logger.info_once(f'Setting template_backend: {template_backend}') res_context_list, loss_scale_list, answer_len = ( @@ -1337,10 +1333,7 @@ def pre_forward_hook(self, model: nn.Module, args, kwargs): def is_training(self): return self.mode not in {'vllm', 'lmdeploy', 'sglang', 'pt'} - def set_mode( - self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto', 'gkd', 'embedding', 'reranker', - 'generative_reranker'] - ) -> None: + def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'train', 'rlhf', 'kto', 'gkd']) -> None: self.mode = mode def register_post_encode_hook(self, models: List[nn.Module]) -> None: @@ -1383,19 +1376,22 @@ def remove_post_encode_hook(self): def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = 
None) -> Dict[str, Any]: from swift.llm import RowPreprocessor - if self.mode == 'rlhf': - res = self._rlhf_data_collator(batch, padding_to=padding_to) - elif self.mode == 'kto': - res = self._kto_data_collator(batch, padding_to=padding_to) - elif self.mode == 'gkd': - res = self._gkd_data_collator(batch, padding_to=padding_to) - elif self.mode in {'pt', 'train', 'prm'}: + if self.task_type == 'causal_lm': + if self.mode in {'pt', 'train'}: + res = self._data_collator(batch, padding_to=padding_to) + elif self.mode == 'rlhf': + res = self._rlhf_data_collator(batch, padding_to=padding_to) + elif self.mode == 'kto': + res = self._kto_data_collator(batch, padding_to=padding_to) + elif self.mode == 'gkd': + res = self._gkd_data_collator(batch, padding_to=padding_to) + elif self.task_type == 'prm': res = self._data_collator(batch, padding_to=padding_to) - elif self.mode == 'seq_cls': + elif self.task_type == 'seq_cls': res = self._seq_cls_data_collator(batch, padding_to=padding_to) - elif self.mode == 'embedding': + elif self.task_type == 'embedding': res = self._embedding_data_collator(batch, padding_to=padding_to) - elif self.mode in ['reranker', 'generative_reranker']: + elif self.task_type in {'reranker', 'generative_reranker'}: res = self._reranker_data_collator(batch, padding_to=padding_to) if not self.remove_unused_columns: extra_kwargs = [b['_extra_kwargs'] for b in batch if b.get('_extra_kwargs') is not None] @@ -1714,11 +1710,11 @@ def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[s ] # For reranker/embedding modes, also check prefixed keys - if self.mode in {'reranker', 'generative_reranker', 'embedding'}: + if self.task_type in {'reranker', 'generative_reranker', 'embedding'}: prefixes = [] - if self.mode in {'reranker', 'generative_reranker'}: + if self.task_type in {'reranker', 'generative_reranker'}: prefixes = ['positive_', 'negative_'] - elif self.mode == 'embedding': + elif self.task_type == 'embedding': prefixes = ['anchor_', 'positive_', 'negative_'] # Add prefixed keys for reranker/embedding modes @@ -1746,7 +1742,7 @@ def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[s for key in keys_to_check: # Skip labels completely for certain modes - if key.endswith('labels') and self.mode in {'reranker', 'generative_reranker'}: + if key.endswith('labels') and self.task_type in {'reranker', 'generative_reranker'}: continue val = inputs.get(key) # fix val is a tensor @@ -1755,7 +1751,7 @@ def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[s if val is not None: key_upper = key.upper() logger.info(f'[{key_upper}_IDS] {val}') - if key.endswith('labels') and self.mode in {'seq_cls', 'embedding'}: + if key.endswith('labels') and self.task_type in {'seq_cls', 'embedding'}: continue if isinstance(val, (list, tuple, torch.Tensor)): # Handle nested lists (e.g., for reranker negative samples) From 6ad5f491b66327278a3cbaa403558a0eb10fdf14 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 15:54:15 +0800 Subject: [PATCH 17/21] update --- swift/llm/infer/infer_engine/pt_engine.py | 4 ++-- swift/llm/template/base.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index 3e234d580e..c040ca6c6c 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -509,8 +509,8 @@ def _gen_wrapper(): return _gen_wrapper() else: if len(kwargs) > 0: - infer_func = 
self._infer_forward if template.task_type in ('seq_cls', 'prm', - 'embedding') else self._infer_full + infer_func = self._infer_forward if template.task_type in {'seq_cls', 'prm', + 'embedding'} else self._infer_full res = infer_func(**kwargs) else: res = [] diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 302d92080b..0a91ecda8f 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -21,7 +21,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils import strtobool -from swift.utils import get_dist_setting, get_env_args, get_logger +from swift.utils import get_env_args, get_logger from ..utils import Processor, ProcessorMixin from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by From 8e3a4032f3d3ad8fc4d232a23b32fc30aa8d592b Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 16:19:30 +0800 Subject: [PATCH 18/21] update --- swift/llm/infer/infer_engine/grpo_vllm_engine.py | 2 ++ swift/llm/infer/infer_engine/infer_engine.py | 1 + swift/llm/infer/infer_engine/pt_engine.py | 6 ++++-- swift/llm/infer/infer_engine/sglang_engine.py | 5 +++-- swift/llm/infer/infer_engine/vllm_engine.py | 12 +++++++----- swift/trainers/arguments.py | 2 +- 6 files changed, 18 insertions(+), 10 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 44689b684f..12aac65057 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -47,6 +47,7 @@ def __init__( limit_mm_per_prompt: Optional[Dict[str, Any]] = None, device: str = 'auto', seed: Optional[int] = None, + task_type: Optional[str] = None, # lora enable_lora: bool = False, max_loras: int = 1, @@ -78,6 +79,7 @@ def __init__( limit_mm_per_prompt=limit_mm_per_prompt, device=device, seed=seed, + task_type=task_type, enable_lora=enable_lora, max_loras=max_loras, max_lora_rank=max_lora_rank, diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py index 165a5c5979..4e1c903d17 100644 --- a/swift/llm/infer/infer_engine/infer_engine.py +++ b/swift/llm/infer/infer_engine/infer_engine.py @@ -30,6 +30,7 @@ def _post_init(self, template=None): self.model_dir = self.model_info.model_dir self.model_name = self.model_info.model_name self.max_model_len = self.model_info.max_model_len + self.task_type = self.model_info.task_type self.config = self.model_info.config if template is None: ckpt_dir = get_ckpt_dir(self.model_dir, getattr(self, 'adapters', None)) diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index c040ca6c6c..b48dc8f69e 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -54,6 +54,7 @@ def __init__( # model kwargs attn_impl: Literal['flash_attn', 'sdpa', 'eager', None] = None, device_map: Optional[Union[str, Dict[str, Any]]] = None, + task_type: Optional[str] = None, quantization_config=None, model_kwargs: Optional[Dict[str, Any]] = None, template: Optional[Template] = None, @@ -70,6 +71,7 @@ def __init__( device_map=device_map, quantization_config=quantization_config, attn_impl=attn_impl, + task_type=task_type, model_kwargs=model_kwargs, **kwargs) self.max_batch_size = max_batch_size @@ -509,8 +511,8 @@ def _gen_wrapper(): return _gen_wrapper() else: if len(kwargs) 
> 0: - infer_func = self._infer_forward if template.task_type in {'seq_cls', 'prm', - 'embedding'} else self._infer_full + infer_func = self._infer_forward if template.task_type in {'seq_cls', 'prm', 'embedding' + } else self._infer_full res = infer_func(**kwargs) else: res = [] diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index 4711d15351..5e991177ac 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -43,13 +43,13 @@ def __init__( context_length: Optional[int] = None, disable_cuda_graph: bool = False, quantization: Optional[str] = None, + task_type: Optional[str] = None, kv_cache_dtype: str = 'auto', enable_dp_attention: bool = False, disable_custom_all_reduce: bool = True, log_level='error', engine_kwargs: Optional[Dict[str, Any]] = None, template: Optional[Template] = None, - task_type: Optional[str] = None, ): if engine_kwargs is None: engine_kwargs = {} @@ -61,7 +61,8 @@ def __init__( model_type=model_type, use_hf=use_hf, hub_token=hub_token, - revision=revision)[1] + revision=revision, + task_type=task_type)[1] self._post_init(template) if context_length is not None: self.max_model_len = context_length diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 12dcba1e00..33c5dff52a 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -59,6 +59,7 @@ def __init__( limit_mm_per_prompt: Optional[Dict[str, Any]] = None, device: str = 'auto', seed: Optional[int] = None, + task_type: Optional[str] = None, # embedding # lora enable_lora: bool = False, max_loras: int = 1, @@ -69,12 +70,10 @@ def __init__( quantization: Optional[str] = None, engine_kwargs: Optional[Dict[str, Any]] = None, template: Optional[Template] = None, - task_type: Optional[str] = None, ) -> None: if engine_kwargs is None: engine_kwargs = {} patch_vllm_memory_leak() - self.task_type = task_type self.use_async_engine = use_async_engine self.processor = get_model_tokenizer( model_id_or_path, @@ -84,7 +83,8 @@ def __init__( model_type=model_type, use_hf=use_hf, hub_token=hub_token, - revision=revision)[1] + revision=revision, + task_type=task_type)[1] self._post_init(template) self._prepare_engine_kwargs( @@ -147,6 +147,8 @@ def _prepare_engine_kwargs( task: Optional[str] = None, **engine_kwargs, ) -> None: + if task == 'embedding': + task = 'embed' disable_log_stats = engine_kwargs.pop('disable_log_stats', True) if self.use_async_engine: engine_cls = AsyncEngineArgs @@ -282,7 +284,7 @@ def _add_request(self, mm_data = {key.rstrip('s'): media_data[0]} if mm_data: llm_inputs['multi_modal_data'] = mm_data - if self.task_type == 'embed': + if self.task_type == 'embedding': from vllm.pooling_params import PoolingParams return self.engine.encode(llm_inputs, PoolingParams(), request_id) elif self.use_async_engine: @@ -440,7 +442,7 @@ async def _infer_full_async( result = None async for result in result_generator: pass - if self.task_type == 'embed': + if self.task_type == 'embedding': return self._create_embedding_response(result, template, generation_config, request_id) else: return self._create_chat_completion_response(result, template, request_config, request_id) diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 189b1eca4a..6085dd4794 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -223,7 +223,7 @@ def get_vllm_engine_kwargs(self): 
'quantization': self.vllm_quantization, } if self.task_type == 'embedding': - kwargs['task_type'] = 'embed' + kwargs['task_type'] = 'embedding' return kwargs From df2baf999a3f09a46ffcbfcea6c4f28605371174 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 16:24:16 +0800 Subject: [PATCH 19/21] update --- examples/train/moe/qwen2_5_moe.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/train/moe/qwen2_5_moe.sh b/examples/train/moe/qwen2_5_moe.sh index 9677a17cf8..1f1bdad1f1 100644 --- a/examples/train/moe/qwen2_5_moe.sh +++ b/examples/train/moe/qwen2_5_moe.sh @@ -1,7 +1,7 @@ # Manually select `target_modules` to avoid 'all-linear' selecting 'gate' CUDA_VISIBLE_DEVICES=0,1 \ swift sft \ - --model Qwen/Qwen2-57B-A14B-Instruct \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ --train_type lora \ --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \ 'AI-ModelScope/alpaca-gpt4-data-en#500' \ From 7148562cb95acb94bef407b784ce4dadf731b3e7 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 17:15:53 +0800 Subject: [PATCH 20/21] fix --- swift/llm/infer/infer_engine/sglang_engine.py | 2 +- swift/llm/template/base.py | 11 +++++------ swift/llm/train/sft.py | 3 +-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index 5e991177ac..7b98686002 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -173,7 +173,7 @@ async def infer_async(self, if template is None: template = self.default_template - template.set_mode('pt') + template.set_mode('sglang') loop = asyncio.get_running_loop() with torch.inference_mode(): inputs = await loop.run_in_executor(None, template.encode, infer_request) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 0a91ecda8f..f6cd5c2073 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -118,7 +118,7 @@ def __init__( self.norm_bbox = norm_bbox or self.norm_bbox if self.is_encoder_decoder: self.skip_prompt = False - self.mode: Literal['pt', 'vllm', 'lmdeploy', # infer + self.mode: Literal['pt', 'vllm', 'lmdeploy', 'sglang', # infer 'train', 'rlhf', 'kto', 'gkd'] = 'pt' # train self.task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'prm', 'reranker', 'generative_reranker'] = 'causal_lm' @@ -495,7 +495,7 @@ def encode(self, extra_kwargs = {} if isinstance(inputs, dict): inputs = deepcopy(inputs) - if not self.is_training: + if self.task_type == 'causal_lm' and not self.is_training: InferRequest.remove_response(inputs['messages']) inputs, extra_kwargs = StdTemplateInputs.from_dict(inputs) elif isinstance(inputs, StdTemplateInputs): @@ -504,7 +504,7 @@ def encode(self, self._preprocess_inputs(inputs) if self.task_type == 'causal_lm': - if self.mode in {'pt', 'train', 'vllm', 'lmdeploy'}: + if self.mode in {'train', 'pt', 'vllm', 'lmdeploy', 'sglang'}: encoded = self._encode_truncated(inputs) elif self.mode == 'rlhf': encoded = self._rlhf_encode(inputs) @@ -1119,7 +1119,6 @@ def _swift_encode(self, inputs: StdTemplateInputs): context_list.append('{{RESPONSE}}') # self.is_training needed because we may want to continue generation from # the current response - # TODO Refactor me if self.is_training and not sep_token or self.task_type == 'embedding': extra_context_list = template_meta.suffix extra_context_type = ContextType.SUFFIX @@ -1331,9 +1330,9 @@ def pre_forward_hook(self, model: nn.Module, args, kwargs): @property def 
is_training(self): - return self.mode not in {'vllm', 'lmdeploy', 'sglang', 'pt'} + return self.mode not in {'pt', 'vllm', 'lmdeploy', 'sglang'} - def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'train', 'rlhf', 'kto', 'gkd']) -> None: + def set_mode(self, mode: Literal['pt', 'vllm', 'lmdeploy', 'sglang', 'train', 'rlhf', 'kto', 'gkd']) -> None: self.mode = mode def register_post_encode_hook(self, models: List[nn.Module]) -> None: diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index cd0aef5eb0..0d9e043338 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -54,8 +54,7 @@ def _prepare_model_tokenizer(self, load_model=True): def _prepare_template(self) -> None: template = self.args.get_template(self.processor) - if self.args.task_type == 'causal_lm': - template.set_mode('train') + template.set_mode('train') if template.use_model: template.model = self.model self.template = template From d5f6d0b7df26c2d176c6f7a2cac26a254b6bc4dd Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 30 Jul 2025 18:04:27 +0800 Subject: [PATCH 21/21] update --- ...Megatron-SWIFT\350\256\255\347\273\203.md" | 3 +- ...44\350\241\214\345\217\202\346\225\260.md" | 7 ++- .../Instruction/Command-line-parameters.md | 7 ++- .../Instruction/Megatron-SWIFT-Training.md | 3 +- examples/export/cached_dataset/mcore.sh | 48 +++++++++++++++++++ .../reranker/train_generative_reranker.sh | 1 + .../train_generative_reranker_listwise.sh | 1 + examples/train/reranker/train_reranker.sh | 1 + .../train/reranker/train_reranker_listwise.sh | 1 + swift/llm/argument/base_args/base_args.py | 9 ++-- swift/llm/argument/export_args.py | 6 +++ 11 files changed, 77 insertions(+), 10 deletions(-) create mode 100644 examples/export/cached_dataset/mcore.sh diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index 25c0a05359..d652035c1f 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -437,8 +437,9 @@ Megatron训练参数继承自Megatron参数和基本参数。基本参数的内 - 🔥packing: 是否使用序列packing,默认为False。当前支持`megatron pt/sft`。 - packing_cache: 指定 packing 缓存目录。默认值为`None`,表示缓存将存储在环境变量 `$MODELSCOPE_CACHE`所指定的路径下。在跨节点使用 packing 功能时,需确保所有节点的 packing 缓存路径共享且一致。你可以通过设置`MODELSCOPE_CACHE`环境变量,或在命令行中添加 `--packing_cache `参数来实现这一要求。 - 注意:该参数将在"ms-swift>=3.7"被移除。多机packing不再需要设置packing_cache。 -- 🔥streaming: 流式读取并处理数据集,默认False。通常在处理大型数据集时,设置为True。更多流式的参数查看命令行参数文档。 +- streaming: 流式读取并处理数据集,默认False。通常在处理大型数据集时,设置为True。更多流式的参数查看命令行参数文档。 - lazy_tokenize: 默认为False。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(这可以避免在训练中出现报错);设置为True,则在训练中对数据集进行tokenize(这可以节约内存)。 +- 🔥cached_dataset: 训练中使用缓存数据集(使用`swift export --to_cached_dataset true ...`命令产生),避免大型数据集训练时,tokenize占用gpu时。默认为`[]`。 - max_epochs: 训练到`max_epochs`时强制退出训练,并对权重进行验证和保存。该参数在使用流式数据集时很有用。默认为None。 - 注意:如果你使用非流式数据集,该参数会为你自动计算train_iters,你不需要手动传入`train_iters`。 diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 9258a7d24e..07e218f44f 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -58,7 +58,7 @@ - dataset_shuffle: 是否对dataset进行随机操作。默认为True。 - 
注意:CPT/SFT的随机包括两个部分:数据集的随机,由`dataset_shuffle`控制;train_dataloader中的随机,由`train_dataloader_shuffle`控制。 - val_dataset_shuffle: 是否对val_dataset进行随机操作。默认为False。 -- 🔥streaming: 流式读取并处理数据集,默认False。 +- streaming: 流式读取并处理数据集,默认False。 - 注意:需要额外设置`--max_steps`,因为流式数据集无法获得其长度。你可以通过设置`--save_strategy epoch`并设置较大的max_steps来实现与`--num_train_epochs`等效的训练。或者,你也可以设置`max_epochs`确保训练到对应epochs时退出训练,并对权重进行验证和保存。 - 注意:流式数据集可以跳过预处理等待,将预处理时间与训练时间重叠。流式数据集的预处理只在rank0上进行,并通过数据分发的方式同步到其他进程,其通常效率不如非流式数据集采用的数据分片读取方式。当训练的world_size较大时,预处理和数据分发将成为训练瓶颈。 - interleave_prob: 默认值为 None。在组合多个数据集时,默认使用 `concatenate_datasets` 函数;如果设置了该参数,则会使用 `interleave_datasets` 函数。该参数通常用于流式数据集的组合,并会作为参数传入 `interleave_datasets` 函数中。 @@ -393,7 +393,8 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数. - 支持的多模态模型参考:https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh。注意:请使用"ms-swift>=3.6",关注[此PR](https://github.com/modelscope/ms-swift/pull/4838)。 - packing_cache: 指定 packing 缓存目录。默认值为`None`,表示缓存将存储在环境变量 `$MODELSCOPE_CACHE`所指定的路径下。在跨节点使用 packing 功能时,需确保所有节点的 packing 缓存路径共享且一致。你可以通过设置`MODELSCOPE_CACHE`环境变量,或在命令行中添加 `--packing_cache `参数来实现这一要求。 - 注意:该参数将在"ms-swift>=3.7"被移除。多机packing不再需要设置packing_cache。 -- 🔥lazy_tokenize: 是否使用lazy_tokenize。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(多模态模型则包括从磁盘中读取图片)。该参数在LLM训练中默认设置为False,而MLLM训练默认为True,节约内存。 +- lazy_tokenize: 是否使用lazy_tokenize。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(多模态模型则包括从磁盘中读取图片)。该参数在LLM训练中默认设置为False,而MLLM训练默认为True,节约内存。 +- 🔥cached_dataset: 训练中使用缓存数据集(使用`swift export --to_cached_dataset true ...`命令产生),避免大型数据集训练时,tokenize占用gpu时。默认为`[]`。 - use_logits_to_keep: 通过在`forward`中根据labels传入logits_to_keep,减少无效logits的计算与存储,从而减少显存占用并加快训练速度。默认为None,进行自动选择。 - 注意:为了稳定性,多模态模型该值默认为False,需要手动设置。 - acc_strategy: 训练和验证时计算acc的策略。可选为`seq`和`token`级别的acc,默认为`token`。 @@ -614,6 +615,8 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数) - max_length: 校准集的max_length, 默认值2048。 - quant_batch_size: 量化batch_size,默认为1。 - group_size: 量化group大小,默认为128。 +- to_cached_dataset: 提前对数据集进行tokenize并导出,默认为False。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset)。 + - 注意:数据packing在训练时进行,而不在此处。 - to_ollama: 产生ollama所需的Modelfile文件。默认为False。 - 🔥to_mcore: HF格式权重转成Megatron格式。默认为False。 - to_hf: Megatron格式权重转成HF格式。默认为False。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 9ffc61aacb..c349734af3 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -58,7 +58,7 @@ Hints: - dataset_shuffle: Whether to shuffle the dataset. Defaults to True. - Note: The shuffling in CPT/SFT consists of two parts: dataset shuffling, controlled by `dataset_shuffle`; and shuffling in the train_dataloader, controlled by `train_dataloader_shuffle`. - val_dataset_shuffle: Whether to perform shuffling on the val_dataset. Default is False. -- 🔥streaming: Stream reading and processing of the dataset, default is False. +- streaming: Stream reading and processing of the dataset, default is False. - Note: You need to set `--max_steps` explicitly, as the streaming dataset does not have a defined length. You can achieve training equivalent to `--num_train_epochs` by setting `--save_strategy epoch` and specifying a sufficiently large `max_steps`. Alternatively, you can set `max_epochs` to ensure training exits after the corresponding number of epochs, at which point the model weights will be validated and saved. 
- Note: Streaming datasets can skip preprocessing wait time by overlapping preprocessing with training. Preprocessing for streaming datasets is performed only on rank 0 and then synchronized to other processes via data distribution. This approach is generally less efficient than the data sharding and reading method used by non-streaming datasets. When the world size is large, preprocessing and data distribution can become a training bottleneck. - interleave_prob: Defaults to None. When combining multiple datasets, the `concatenate_datasets` function is used by default. If this parameter is set, the `interleave_datasets` function will be used instead. This parameter is typically used when combining streaming datasets and is passed to the `interleave_datasets` function. @@ -402,7 +402,8 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine - Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh. Note: Please use "ms-swift>=3.6" and follow [this PR](https://github.com/modelscope/ms-swift/pull/4838). - packing_cache: Specifies the directory for packing cache. The default value is `None`, which means the cache will be stored in the path defined by the environment variable `$MODELSCOPE_CACHE`. When using the packing feature across multiple nodes, ensure that all nodes share the same packing cache directory. You can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding the `--packing_cache ` argument in the command line. - Note: This parameter will be removed in "ms-swift>=3.7". The `packing_cache` setting will no longer be required for multi-node packing. -- 🔥lazy_tokenize: Whether to use lazy tokenization. If set to False, all dataset samples are tokenized before training (for multimodal models, this includes reading images from disk). This parameter defaults to False for LLM training, and True for MLLM training, to save memory. +- lazy_tokenize: Whether to use lazy tokenization. If set to False, all dataset samples are tokenized before training (for multimodal models, this includes reading images from disk). This parameter defaults to False for LLM training, and True for MLLM training, to save memory. +- 🔥cached_dataset: Use a cached dataset (generated with `swift export --to_cached_dataset true ...`) during training to avoid GPU time spent on tokenizing large datasets. Default: `[]`. - use_logits_to_keep: Pass `logits_to_keep` in the `forward` method based on labels to reduce the computation and storage of unnecessary logits, thereby reducing memory usage and accelerating training. The default is `None`, which enables automatic selection. - Note: For stability, this value is set to False by default for multimodal models and needs to be manually enabled. - acc_strategy: Strategy for calculating accuracy during training and validation. Options are `seq`-level and `token`-level accuracy, with `token` as the default. @@ -633,6 +634,8 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum - max_length: Max length for the calibration set, default value is 2048. - quant_batch_size: Quantization batch size, default is 1. - group_size: Group size for quantization, default is 128. +- to_cached_dataset: pre-tokenize the dataset and export it in advance, default is False. See the example [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset). + - Note: data packing is performed during training, not in this step. 
- to_ollama: Generate the Modelfile required by Ollama. Default is False. - 🔥to_mcore: Convert weights from HF format to Megatron format. Default is False. - to_hf: Convert weights from Megatron format to HF format. Default is False. diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index 58d7b73b5b..cfdc10f53b 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -456,8 +456,9 @@ Megatron training parameters inherit from Megatron parameters and basic paramete - 🔥packing: Whether to use sequence packing, defaults to False. Currently supports `megatron pt/sft`. - packing_cache: Specifies the directory for packing cache. The default value is `None`, which means the cache will be stored in the path defined by the environment variable `$MODELSCOPE_CACHE`. When using the packing feature across multiple nodes, ensure that all nodes share the same packing cache directory. You can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding the `--packing_cache ` argument in the command line. - Note: This parameter will be removed in "ms-swift>=3.7". The `packing_cache` setting will no longer be required for multi-node packing. -- 🔥streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. For more information on streaming parameters, refer to the command-line parameters documentation. +- streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. For more information on streaming parameters, refer to the command-line parameters documentation. - lazy_tokenize: Default is False. If this parameter is set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization occurs during training (this saves memory). +- 🔥cached_dataset: Use a cached dataset (generated with `swift export --to_cached_dataset true ...`) during training to avoid GPU time spent on tokenizing large datasets. Default: `[]`. - max_epochs: Forces the training to exit after reaching `max_epochs`, and performs validation and saving of the model weights. This parameter is especially useful when using a streaming dataset. Default is None. - Note: If you use a non-streaming dataset, this parameter will automatically calculate train_iters for you, so there is no need to pass `train_iters` manually. 
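[Editor's note, not part of the patch] A minimal sketch of the non-Megatron flow that the parameter docs above describe: the dataset is tokenized once with `swift export --to_cached_dataset true`, and a plain `swift sft` run then consumes it via `--cached_dataset`. Flag names follow the documentation in this patch; the model, dataset, and paths are illustrative placeholders, and the exact defaults (e.g. whether `--model` may be omitted at export time) are assumptions.

# Sketch only: values are placeholders, flags follow the parameter docs above.
swift export \
    --model Qwen/Qwen2.5-7B-Instruct \
    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh' \
    --max_length 2048 \
    --dataset_num_proc 16 \
    --to_cached_dataset true \
    --output_dir ./cached_dataset

# Training then reads the pre-tokenized samples instead of tokenizing on the fly;
# keep --max_length consistent with the export step.
swift sft \
    --model Qwen/Qwen2.5-7B-Instruct \
    --cached_dataset ./cached_dataset \
    --train_type lora \
    --max_length 2048 \
    --num_train_epochs 1

The Megatron variant of this flow is shown in the `examples/export/cached_dataset/mcore.sh` file added by the diff that follows.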
diff --git a/examples/export/cached_dataset/mcore.sh b/examples/export/cached_dataset/mcore.sh new file mode 100644 index 0000000000..ef59ead23d --- /dev/null +++ b/examples/export/cached_dataset/mcore.sh @@ -0,0 +1,48 @@ +# +swift export \ + --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT' \ + --max_length 8192 \ + --dataset_num_proc 64 \ + --to_cached_dataset true \ + --output_dir ./cached_dataset + + +# 4 * 95GiB +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +megatron sft \ + --load Qwen3-30B-A3B-Base-mcore \ + --cached_dataset './cached_dataset' \ + --train_type lora \ + --lora_rank 32 \ + --lora_alpha 64 \ + --target_modules all-linears \ + --split_dataset_ratio 0.01 \ + --moe_permute_fusion true \ + --expert_model_parallel_size 4 \ + --moe_grouped_gemm true \ + --moe_shared_expert_overlap true \ + --moe_aux_loss_coeff 1e-3 \ + --micro_batch_size 1 \ + --global_batch_size 16 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 2 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-4 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-5 \ + --save megatron_output/Qwen3-30B-A3B-Base \ + --eval_interval 200 \ + --save_interval 200 \ + --packing true \ + --max_length 8192 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash diff --git a/examples/train/reranker/train_generative_reranker.sh b/examples/train/reranker/train_generative_reranker.sh index cf946aa61e..e08be2472f 100644 --- a/examples/train/reranker/train_generative_reranker.sh +++ b/examples/train/reranker/train_generative_reranker.sh @@ -17,6 +17,7 @@ swift sft \ --per_device_train_batch_size 2 \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 8 \ + --dataset_num_proc 8 \ --learning_rate 6e-6 \ --label_names labels \ --dataloader_drop_last true diff --git a/examples/train/reranker/train_generative_reranker_listwise.sh b/examples/train/reranker/train_generative_reranker_listwise.sh index fd9acbdbe0..266369b7e9 100644 --- a/examples/train/reranker/train_generative_reranker_listwise.sh +++ b/examples/train/reranker/train_generative_reranker_listwise.sh @@ -17,6 +17,7 @@ swift sft \ --per_device_train_batch_size 2 \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 8 \ + --dataset_num_proc 8 \ --learning_rate 6e-6 \ --label_names labels \ --dataloader_drop_last true diff --git a/examples/train/reranker/train_reranker.sh b/examples/train/reranker/train_reranker.sh index 50992cc8bb..2bc8938608 100644 --- a/examples/train/reranker/train_reranker.sh +++ b/examples/train/reranker/train_reranker.sh @@ -15,6 +15,7 @@ swift sft \ --per_device_train_batch_size 64 \ --per_device_eval_batch_size 64 \ --gradient_accumulation_steps 1 \ + --dataset_num_proc 8 \ --learning_rate 6e-6 \ --label_names labels \ --dataloader_drop_last true \ diff --git a/examples/train/reranker/train_reranker_listwise.sh b/examples/train/reranker/train_reranker_listwise.sh index b08fd8d3c9..13ffa32928 100644 --- a/examples/train/reranker/train_reranker_listwise.sh +++ b/examples/train/reranker/train_reranker_listwise.sh @@ -15,6 +15,7 @@ swift sft \ --per_device_train_batch_size 64 \ --per_device_eval_batch_size 64 \ --gradient_accumulation_steps 1 \ + --dataset_num_proc 8 \ --learning_rate 6e-6 \ --label_names labels \ --dataloader_drop_last true \ diff --git a/swift/llm/argument/base_args/base_args.py 
b/swift/llm/argument/base_args/base_args.py index 4577278c41..17d2295c74 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -102,16 +102,17 @@ def _prepare_training_args(self, training_args: Dict[str, Any]) -> None: pass def _init_lazy_tokenize(self): - if self.streaming and self.lazy_tokenize: - self.lazy_tokenize = False - logger.warning('Streaming and lazy_tokenize are incompatible. ' - f'Setting args.lazy_tokenize: {self.lazy_tokenize}.') if self.lazy_tokenize is None: if self.model_meta.is_multimodal and not self.streaming and not self.packing: self.lazy_tokenize = True else: self.lazy_tokenize = False logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}') + if self.lazy_tokenize: + if self.packing: + raise ValueError('Packing and lazy_tokenize are incompatible.') + if self.streaming: + raise ValueError('Streaming and lazy_tokenize are incompatible.') def _init_custom_register(self) -> None: """Register custom .py file to datasets""" diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py index 1b6539e57c..7bac558d5b 100644 --- a/swift/llm/argument/export_args.py +++ b/swift/llm/argument/export_args.py @@ -95,6 +95,12 @@ def _init_output_dir(self): logger.info(f'args.output_dir: `{self.output_dir}`') def __post_init__(self): + if self.to_cached_dataset: + if self.packing: + raise ValueError( + 'Packing will be handled during training; here we only perform tokenization ' + 'in advance, so you do not need to set up packing separately.') + assert not self.streaming and not self.lazy_tokenize, 'not supported' if self.quant_batch_size == -1: self.quant_batch_size = None if isinstance(self.mcore_adapters, str):