diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index 3633821800..0e2ef01d0b 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -487,6 +487,7 @@ |[AI-ModelScope/Mistral-7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.1)|mistral|llama|transformers>=4.34|✘|-|[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)| |[AI-ModelScope/Mistral-7B-v0.2-hf](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.2-hf)|mistral|llama|transformers>=4.34|✘|-|[alpindale/Mistral-7B-v0.2-hf](https://huggingface.co/alpindale/Mistral-7B-v0.2-hf)| |[swift/Codestral-22B-v0.1](https://modelscope.cn/models/swift/Codestral-22B-v0.1)|mistral|llama|transformers>=4.34|✘|-|[mistralai/Codestral-22B-v0.1](https://huggingface.co/mistralai/Codestral-22B-v0.1)| +|[mistralai/Devstral-Small-2505](https://modelscope.cn/models/mistralai/Devstral-Small-2505)|devstral|devstral|transformers>=4.43, mistral-common>=1.5.5|✘|-|[mistralai/Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505)| |[modelscope/zephyr-7b-beta](https://modelscope.cn/models/modelscope/zephyr-7b-beta)|zephyr|zephyr|transformers>=4.34|✘|-|[HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)| |[AI-ModelScope/Mixtral-8x7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-Instruct-v0.1)|mixtral|llama|transformers>=4.36|✘|-|[mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)| |[AI-ModelScope/Mixtral-8x7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-v0.1)|mixtral|llama|transformers>=4.36|✘|-|[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 7df8f01d99..027ae308c9 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -487,6 +487,7 @@ The table below introduces the models integrated with ms-swift: |[AI-ModelScope/Mistral-7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.1)|mistral|llama|transformers>=4.34|✘|-|[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)| |[AI-ModelScope/Mistral-7B-v0.2-hf](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.2-hf)|mistral|llama|transformers>=4.34|✘|-|[alpindale/Mistral-7B-v0.2-hf](https://huggingface.co/alpindale/Mistral-7B-v0.2-hf)| |[swift/Codestral-22B-v0.1](https://modelscope.cn/models/swift/Codestral-22B-v0.1)|mistral|llama|transformers>=4.34|✘|-|[mistralai/Codestral-22B-v0.1](https://huggingface.co/mistralai/Codestral-22B-v0.1)| +|[mistralai/Devstral-Small-2505](https://modelscope.cn/models/mistralai/Devstral-Small-2505)|devstral|devstral|transformers>=4.43, mistral-common>=1.5.5|✘|-|[mistralai/Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505)| 
|[modelscope/zephyr-7b-beta](https://modelscope.cn/models/modelscope/zephyr-7b-beta)|zephyr|zephyr|transformers>=4.34|✘|-|[HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)| |[AI-ModelScope/Mixtral-8x7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-Instruct-v0.1)|mixtral|llama|transformers>=4.36|✘|-|[mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)| |[AI-ModelScope/Mixtral-8x7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-v0.1)|mixtral|llama|transformers>=4.36|✘|-|[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)| diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index 8f31fd80da..4577278c41 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -63,6 +63,8 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat seed (int): Random seed for reproducibility. Default is 42. model_kwargs (Optional[str]): Additional keyword arguments for the model. Default is None. load_data_args (bool): Flag to determine if dataset configuration should be loaded. Default is False. + packing (bool): Flag to enable packing of datasets. Default is False. + lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None. use_hf (bool): Flag to determine if Hugging Face should be used. Default is False. hub_token (Optional[str]): SDK token for authentication. Default is None. custom_register_path (List[str]): Path to custom .py file for dataset registration. Default is None. @@ -80,6 +82,8 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat load_data_args: bool = False # dataset packing: bool = False + lazy_tokenize: Optional[bool] = None + cached_dataset: List[str] = field(default_factory=list) custom_register_path: List[str] = field(default_factory=list) # .py # hub use_hf: bool = False @@ -97,6 +101,18 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat def _prepare_training_args(self, training_args: Dict[str, Any]) -> None: pass + def _init_lazy_tokenize(self): + if self.streaming and self.lazy_tokenize: + self.lazy_tokenize = False + logger.warning('Streaming and lazy_tokenize are incompatible. 
' + f'Setting args.lazy_tokenize: {self.lazy_tokenize}.') + if self.lazy_tokenize is None: + if self.model_meta.is_multimodal and not self.streaming and not self.packing: + self.lazy_tokenize = True + else: + self.lazy_tokenize = False + logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}') + def _init_custom_register(self) -> None: """Register custom .py file to datasets""" if isinstance(self.custom_register_path, str): @@ -154,7 +170,9 @@ def __post_init__(self): QuantizeArguments.__post_init__(self) TemplateArguments.__post_init__(self) DataArguments.__post_init__(self) - + if isinstance(self.cached_dataset, str): + self.cached_dataset = [self.cached_dataset] + self._init_lazy_tokenize() self.hub = get_hub(self.use_hf) if self.hub.try_login(self.hub_token): logger.info('hub login successful!') diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py index 370ffbc777..1b6539e57c 100644 --- a/swift/llm/argument/export_args.py +++ b/swift/llm/argument/export_args.py @@ -40,6 +40,9 @@ class ExportArguments(MergeArguments, BaseArguments): quant_batch_size: int = 1 group_size: int = 128 + # cached_dataset + to_cached_dataset: bool = False + # ollama to_ollama: bool = False @@ -79,6 +82,8 @@ def _init_output_dir(self): suffix = 'mcore' elif self.to_hf: suffix = 'hf' + elif self.to_cached_dataset: + suffix = 'cached_dataset' else: return diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index 872eb6d57b..33629cb13b 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -109,8 +109,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra Args: add_version (bool): Flag to add version information to output_dir. Default is True. loss_type (Optional[str]): Type of loss function to use. Default is None. - packing (bool): Flag to enable packing of datasets. Default is False. - lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None. max_new_tokens (int): Maximum number of new tokens to generate. Default is 64. temperature (float): Temperature for sampling. Default is 0. optimizer (Optional[str]): Optimizer type to use, define it in the plugin package. Default is None. @@ -118,7 +116,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra """ add_version: bool = True create_checkpoint_symlink: bool = False - lazy_tokenize: Optional[bool] = None # plugin loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'}) @@ -135,15 +132,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra # auto_tp deepspeed_autotp_size: Optional[int] = None - def _init_lazy_tokenize(self): - if self.streaming and self.lazy_tokenize: - self.lazy_tokenize = False - logger.warning('Streaming and lazy_tokenize are incompatible. 
' + f'Setting args.lazy_tokenize: {self.lazy_tokenize}.') - if self.lazy_tokenize is None: - self.lazy_tokenize = self.model_meta.is_multimodal and not self.streaming - logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}') - def __post_init__(self) -> None: if self.padding_free or self.packing: if self.packing: @@ -179,7 +167,6 @@ def __post_init__(self) -> None: self._init_deepspeed() self._init_device() - self._init_lazy_tokenize() if getattr(self, 'accelerator_config', None) is None: self.accelerator_config = {'dispatch_batches': False} diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index ae03a60659..3c5de3343e 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -146,23 +146,18 @@ def __init__( self.strict = strict self.load_from_cache_file = load_from_cache_file self.workers = [] - preprocessor = EncodePreprocessor(template=template) - self.dataset = preprocessor( - dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict) - if template.model_meta.is_multimodal: - self.dataset = LazyLLMDataset(self.dataset, encode_func=template.encode) - self.packed_idx = self.create_packed_idx() if is_master() else None + self.packed_idx, self.packed_length = self.create_packed_idx() if is_master() else (None, None) if dist.is_initialized() and is_dist(): - obj_list = [self.packed_idx] + obj_list = [(self.packed_idx, self.packed_length)] dist.broadcast_object_list(obj_list) - self.packed_idx = obj_list[0] + self.packed_idx, self.packed_length = obj_list[0] def create_packed_idx(self): lengths = self.dataset['length'] data = [(i, length) for i, length in enumerate(lengths)] i = 0 PACKING_BATCH_SIZE = 1000 - input_data, res = [], [] + input_data, packed_idx, packed_length = [], [], [] with tqdm(total=len(data), dynamic_ncols=True, desc='Packing: ') as prog_bar: while True: new_data = data[i:i + PACKING_BATCH_SIZE] @@ -173,14 +168,13 @@ def create_packed_idx(self): i += PACKING_BATCH_SIZE is_finished = i >= len(data) sequences, input_data = calculate_matched_group(self.template, input_data, is_finished=is_finished) - res += sequences - return res + packed_idx += [[x[0] for x in seq] for seq in sequences] + packed_length += [sum(x[1] for x in seq) for seq in sequences] + return packed_idx, packed_length def __getitem__(self, index): sequence = self.packed_idx[index] - row = [] - for i, length in sequence: - row.append((self.dataset[i], length)) + row = [self.dataset[i] for i in sequence] return self.template.packing_row(row) def __len__(self): @@ -221,7 +215,7 @@ def _processor(self): i, data = self._in_queue.get() encoded_data = {} try: - encoded_data = self.template.encode(data) + encoded_data = self.template.encode(data, return_length=True) except Exception as e: if self.strict and not isinstance(e, MaxLengthError): raise @@ -271,7 +265,7 @@ def __iter__(self): sequences, data = calculate_matched_group(self.template, data, is_finished=finished) res = [] for row in sequences: - packed = self.template.packing_row(row) + packed = self.template.packing_row([r[0] for r in row]) res.append(packed) yield from res if finished: @@ -286,7 +280,7 @@ def __init__(self, template: 'Template'): self.is_multimodal = template.model_meta.is_multimodal def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: - encoded = self.template.encode(row) + encoded = self.template.encode(row, return_length=True) if self.is_multimodal: row['length'] = encoded['length'] encoded = row diff --git a/swift/llm/export/__init__.py 
b/swift/llm/export/__init__.py index e3330166f1..6012848527 100644 --- a/swift/llm/export/__init__.py +++ b/swift/llm/export/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .cached_dataset import export_cached_dataset from .export import SwiftExport, export_main from .merge_lora import merge_lora from .ollama import export_to_ollama diff --git a/swift/llm/export/cached_dataset.py b/swift/llm/export/cached_dataset.py new file mode 100644 index 0000000000..5155b83bbf --- /dev/null +++ b/swift/llm/export/cached_dataset.py @@ -0,0 +1,35 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import List, Union + +from swift.llm import ExportArguments +from swift.llm.train import SwiftSft +from swift.utils import get_logger + +logger = get_logger() + + +class ExportCachedDataset(SwiftSft): + args_class = ExportArguments + args: args_class + + def __init__(self, args: Union[List[str], ExportArguments, None] = None) -> None: + super(SwiftSft, self).__init__(args) + self.train_msg = {} # dummy + self.processor = None + self._prepare_template() + self._prepare_model_tokenizer(load_model=self.template.use_model) + self.template.init_processor(self.processor) + + def main(self): + train_dataset, val_dataset = self._get_dataset() + train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) + self._show_dataset(train_dataset, val_dataset) + train_dataset.save_to_disk(os.path.join(self.args.output_dir, 'train')) + if val_dataset is not None: + val_dataset.save_to_disk(os.path.join(self.args.output_dir, 'val')) + logger.info(f'Dataset saved to `{self.args.output_dir}`') + + +def export_cached_dataset(args: Union[List[str], ExportArguments, None] = None): + return ExportCachedDataset(args).main() diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py index d78658ba62..d82d755d92 100644 --- a/swift/llm/export/export.py +++ b/swift/llm/export/export.py @@ -4,6 +4,7 @@ from swift.llm import ExportArguments, SwiftPipeline from swift.tuners import swift_to_peft_format from swift.utils import get_logger +from .cached_dataset import export_cached_dataset from .merge_lora import merge_lora from .ollama import export_to_ollama from .quant import quantize_model @@ -29,6 +30,8 @@ def run(self): quantize_model(args) elif args.to_ollama: export_to_ollama(args) + elif args.to_cached_dataset: + export_cached_dataset(args) elif args.to_mcore: from swift.megatron import convert_hf2mcore convert_hf2mcore(args) diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py index 5a859d8e14..e0eca6a5ca 100644 --- a/swift/llm/infer/infer.py +++ b/swift/llm/infer/infer.py @@ -250,7 +250,8 @@ def infer_dataset(self) -> List[Dict[str, Any]]: prog_bar.close() metrics = self.infer_kwargs.pop('metrics') if result_list: - print(f'[rank{args.rank}] {metrics[0].compute()}') + metric = metrics[0].compute() + print(f'[rank{args.rank}] {metric}' if args.rank >= 0 else str(metric)) if args.metric is not None: self._calc_metric() return result_list diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py index ba3e9daf11..1124fdb669 100644 --- a/swift/llm/model/constant.py +++ b/swift/llm/model/constant.py @@ -81,6 +81,7 @@ class LLMModelType: telechat2 = 'telechat2' mistral = 'mistral' + devstral = 'devstral' zephyr = 'zephyr' mixtral = 'mixtral' mistral_nemo = 'mistral_nemo' diff --git a/swift/llm/model/model/mistral.py b/swift/llm/model/model/mistral.py index 54b7884fe0..6d23b23b41 100644 --- a/swift/llm/model/model/mistral.py +++ 
b/swift/llm/model/model/mistral.py @@ -9,7 +9,7 @@ from ..model_arch import ModelArch from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, register_model) -from ..utils import ModelInfo +from ..utils import ModelInfo, safe_snapshot_download register_model( ModelMeta( @@ -148,7 +148,8 @@ def get_model_tokenizer_devstral_2505(model_dir: str, load_model: bool = True, **kwargs): # src: sglang did the same (https://github.com/sgl-project/sglang/pull/6547) - tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-Small-3.1-24B-Instruct-2503') + tokenizer_dir = safe_snapshot_download('mistralai/Mistral-Small-3.1-24B-Instruct-2503', download_model=False) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) kwargs['tokenizer'] = tokenizer model, processor = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) @@ -157,14 +158,14 @@ def get_model_tokenizer_devstral_2505(model_dir: str, register_model( ModelMeta( - model_type='devstral', + model_type=LLMModelType.devstral, model_groups=[ ModelGroup([ Model('mistralai/Devstral-Small-2505', 'mistralai/Devstral-Small-2505'), ], requires=['transformers>=4.43', 'mistral-common>=1.5.5']) ], - template='devstral', + template=TemplateType.devstral, get_function=get_model_tokenizer_devstral_2505, architectures=['MistralForCausalLM'], model_arch=ModelArch.llama)) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index ee504c72e5..5ea56ca49f 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -86,7 +86,7 @@ def __init__( from .template_meta import TemplateMeta from swift.plugin import agent_templates, loss_scale_map self._processor_inited = False - self._version = 'v1' # Avoid compatibility issues caused by load_from_cache_file caching. + self._version = 'v2' # Avoid compatibility issues caused by load_from_cache_file caching. self.max_length = max_length self.model = None @@ -488,7 +488,8 @@ def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: @torch.inference_mode() def encode(self, inputs: Union[TemplateInputs, Dict[str, Any], InferRequest], - return_template_inputs: bool = False) -> Dict[str, Any]: + return_template_inputs: bool = False, + return_length: bool = False) -> Dict[str, Any]: """The entrance method of Template! 
Returns: @@ -537,7 +538,7 @@ def encode(self, lengths.append(value) elif isinstance(value, (tuple, list)): lengths += value - if self.is_training: + if return_length: encoded['length'] = max(lengths) else: encoded.pop('length', None) @@ -547,22 +548,24 @@ def encode(self, encoded['_extra_kwargs'] = extra_kwargs return encoded - def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]: + def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]: packed = {} keys = set() + length = [] for r in row: - keys.update(r[0].keys()) + keys.update(r.keys()) + length.append(r['length']) for key in keys: if key in {'input_ids', 'labels', 'loss_scale'}: - packed[key] = sum((x[0][key] for x in row), start=[]) + packed[key] = sum((x[key] for x in row), start=[]) elif key == 'length': - packed[key] = sum((x[0][key] for x in row)) + packed[key] = sum((x[key] for x in row)) elif key == 'channel': - packed[key] = [x[0][key] for x in row] + packed[key] = [x[key] for x in row] if 'position_ids' not in packed: - packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[]) + packed['position_ids'] = sum((list(range(x)) for x in length), start=[]) - packed.update(self._data_collator_mm_data([r[0] for r in row])) + packed.update(self._data_collator_mm_data(row)) return packed def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/swift/llm/template/constant.py b/swift/llm/template/constant.py index 12a8eca329..db7a423e04 100644 --- a/swift/llm/template/constant.py +++ b/swift/llm/template/constant.py @@ -71,6 +71,7 @@ class LLMTemplateType: mistral_nemo = 'mistral_nemo' mistral_2501 = 'mistral_2501' + devstral = 'devstral' zephyr = 'zephyr' wizardlm2 = 'wizardlm2' wizardlm2_moe = 'wizardlm2_moe' diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index e3a3571016..4c4a761851 100644 --- a/swift/llm/template/template/qwen.py +++ b/swift/llm/template/template/qwen.py @@ -381,10 +381,10 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: res[f'{media_type}_grid_thw'] = grid_thw return res - def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]: + def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]: position_ids = [] for r in row: - r = r[0].copy() + r = r.copy() r['input_ids'] = torch.tensor(r['input_ids'])[None] position_ids.append(self._get_position_ids(r)) packed = super().packing_row(row) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index f0fa774527..cd0aef5eb0 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -4,7 +4,9 @@ from typing import List, Union from datasets import Dataset as HfDataset +from datasets import load_from_disk +from swift.llm.dataset.loader import DatasetLoader from swift.plugin import extra_callbacks, get_loss_func, get_metric from swift.trainers import TrainerFactory from swift.utils import append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array @@ -35,13 +37,14 @@ def _prepare_generation_config(self): args.get_request_config(), self.tokenizer) logger.info(f'model.generation_config: {self.model.generation_config}') - def _prepare_model_tokenizer(self): + def _prepare_model_tokenizer(self, load_model=True): args = self.args if args.sequence_parallel_size > 1: from swift.trainers.sequence_parallel import sequence_parallel sequence_parallel.init_sequence_parallel(args.sequence_parallel_size) - self.model, self.processor = 
args.get_model_processor() - + self.model, self.processor = args.get_model_processor(load_model=load_model) + if self.model is None: + return if hasattr(self.model, 'hf_device_map'): logger.info(f'model.hf_device_map: {self.model.hf_device_map}') @@ -88,11 +91,60 @@ def _save_val_dataset(self, val_dataset): append_to_jsonl(val_dataset_path, val_dataset.to_list()) logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.') - def run(self): + def _get_cached_dataset(self): + args = self.args + assert not args.streaming and not args.lazy_tokenize + train_datasets, val_datasets = [], [] + for cached_dataset in args.cached_dataset: + train_path = os.path.join(cached_dataset, 'train') + val_path = os.path.join(cached_dataset, 'val') + train_datasets.append(load_from_disk(train_path)) + if os.path.exists(val_path): + val_datasets.append(load_from_disk(val_path)) + return train_datasets, val_datasets + + def _prepare_dataset(self): args = self.args + if args.cached_dataset: + train_datasets, val_datasets = self._get_cached_dataset() + else: + train_datasets, val_datasets = [], [] + if args.dataset: + train_dataset, val_dataset = self._get_dataset() + train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) + train_datasets.append(train_dataset) + val_datasets.append(val_dataset) + train_dataset = DatasetLoader._concat_datasets(train_datasets) + val_dataset = DatasetLoader._concat_datasets(val_datasets) + is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' + predict_with_generate = getattr(args, 'predict_with_generate', False) + datasets = [train_dataset, val_dataset] + if is_grpo: + return datasets + template = self.template + for i, dataset in enumerate(datasets): + if dataset is None: + continue + if i == 1 and predict_with_generate: + # val_dataset + continue + if (args.model_meta.is_multimodal or args.lazy_tokenize) and not args.streaming: + dataset = LazyLLMDataset(dataset, template.encode, strict=args.strict, random_state=args.data_seed) + if args.packing: + packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset + dataset = packing_dataset_cls( + template, + dataset, + num_proc=args.dataset_num_proc, + strict=args.strict, + load_from_cache_file=args.load_from_cache_file) + datasets[i] = dataset + self._show_dataset(*datasets) + return datasets - train_dataset, val_dataset = self._get_dataset() - train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) + def run(self): + args = self.args + train_dataset, val_dataset = self._prepare_dataset() if args.task_type == 'seq_cls': args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None) @@ -202,73 +254,62 @@ def _prepare_callbacks(self): callbacks += extra_callbacks self.callbacks = callbacks - def _stat_dataset(self, dataset: Union[HfDataset, PackingDataset]): + @staticmethod + def _stat_dataset(dataset: Union[HfDataset, PackingDataset]): if isinstance(dataset, HfDataset): - # TODO: Temporary fix; awaiting template refactor. 
- try: - length = dataset['length'] - except KeyError: - logger.warning_once("The HfDataset is missing the 'length' column, skipping statistics.") - return + length = dataset['length'] else: - length = dataset.dataset['length'] + length = dataset.packed_length _, stat_str = stat_array(length) logger.info(f'Dataset Token Length: {stat_str}') return stat_str + def _show_dataset(self, train_dataset, val_dataset): + args = self.args + predict_with_generate = getattr(args, 'predict_with_generate', False) + if is_master(): + inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset)) + self.template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {}) + elif hasattr(train_dataset, '__len__'): + # Avoid the random mismatch issue in LazyLLMDataset. + inputs = train_dataset[0] + if val_dataset is not None and hasattr(val_dataset, '__len__') and len(val_dataset) == 0: + val_dataset = None + if not args.lazy_tokenize and not args.streaming: + self.train_msg['train_dataset'] = self._stat_dataset(train_dataset) + if val_dataset is not None and not predict_with_generate: + self.train_msg['val_dataset'] = self._stat_dataset(val_dataset) + def _encode_dataset(self, train_dataset, val_dataset): template = self.template args = self.args self._save_val_dataset(val_dataset) + is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' predict_with_generate = getattr(args, 'predict_with_generate', False) datasets = [train_dataset, val_dataset] - if not is_grpo: - origin_template_model = template.model - template.model = None # Avoid serializing the model. - lazy_tokenize = args.lazy_tokenize and not args.packing - for i, dataset in enumerate(datasets): - if dataset is None: - continue - if i == 1 and predict_with_generate: - # val_dataset - continue - if lazy_tokenize: - dataset = LazyLLMDataset(dataset, template.encode, strict=args.strict, random_state=args.data_seed) - elif args.packing: - packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset - dataset = packing_dataset_cls( - self.template, - dataset, - num_proc=args.dataset_num_proc, - strict=args.strict, - load_from_cache_file=args.load_from_cache_file) - else: - preprocessor = EncodePreprocessor(template=template) - dataset = preprocessor( - dataset, - num_proc=args.dataset_num_proc, - load_from_cache_file=args.load_from_cache_file, - strict=args.strict) - if args.model_meta.is_multimodal: - dataset = LazyLLMDataset(dataset, template.encode) - datasets[i] = dataset - template.model = origin_template_model - train_dataset, val_dataset = datasets - if is_master(): - inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset)) - template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {}) - elif hasattr(train_dataset, '__len__'): - # Avoid the random mismatch issue in LazyLLMDataset. - inputs = train_dataset[0] - if val_dataset is not None and hasattr(val_dataset, '__len__') and len(val_dataset) == 0: - val_dataset = None + if is_grpo: + return datasets + + origin_template_model = template.model + template.model = None # Avoid serializing the model. 
+ for i, dataset in enumerate(datasets): + if dataset is None: + continue + if i == 1 and predict_with_generate: + # val_dataset + continue if not args.lazy_tokenize and not args.streaming: - self.train_msg['train_dataset'] = self._stat_dataset(train_dataset) - if val_dataset is not None and not predict_with_generate: - self.train_msg['val_dataset'] = self._stat_dataset(val_dataset) - - return train_dataset, val_dataset + preprocessor = EncodePreprocessor(template=template) + dataset = preprocessor( + dataset, + num_proc=args.dataset_num_proc, + load_from_cache_file=args.load_from_cache_file, + strict=args.strict) + datasets[i] = dataset + template.model = origin_template_model + + return datasets def sft_main(args: Union[List[str], TrainArguments, None] = None): diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 928d87086f..7fa33c90c4 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -15,8 +15,6 @@ @dataclass class MegatronTrainArguments(MegatronArguments, BaseArguments): add_version: bool = True - # dataset - lazy_tokenize: bool = False def init_model_args(self, tokenizer, config): self.megatron_model_meta = get_megatron_model_meta(self.model_type) diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 7b758ad2fa..986d002786 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -48,9 +48,8 @@ def _get_data_collator(self): def run(self): args = self.args + train_dataset, val_dataset = self._prepare_dataset() data_collator = self._get_data_collator() - train_dataset, val_dataset = self._get_dataset() - train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) if args.streaming: train_dataset = build_streaming_dataloader(args, train_dataset, data_collator) diff --git a/tests/train/test_export_cached_dataset.py b/tests/train/test_export_cached_dataset.py new file mode 100644 index 0000000000..02e9006e16 --- /dev/null +++ b/tests/train/test_export_cached_dataset.py @@ -0,0 +1,27 @@ +def test_export_cached_dataset(): + from swift.llm import export_main, ExportArguments + export_main( + ExportArguments( + model='Qwen/Qwen2.5-7B-Instruct', + dataset='swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT', + to_cached_dataset=True, + dataset_num_proc=4, + )) + print() + + +def test_sft(): + from swift.llm import sft_main, TrainArguments + sft_main( + TrainArguments( + model='Qwen/Qwen2.5-7B-Instruct', + dataset='liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT#1000', + dataset_num_proc=2, + packing=True, + attn_impl='flash_attn', + )) + + +if __name__ == '__main__': + # test_export_cached_dataset() + test_sft()
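Reviewer note: a minimal end-to-end sketch of the workflow this patch enables, using only the public entry points already imported in the test file above (export_main, sft_main, ExportArguments, TrainArguments). The 'output/cached_dataset' path is hypothetical; in practice ExportArguments._init_output_dir derives the actual directory and appends the 'cached_dataset' suffix, so use the path printed by the export step.

# Sketch, not the author's documented workflow: pre-tokenize once with the new
# to_cached_dataset export, then reuse the saved 'train'/'val' arrow datasets for SFT.
from swift.llm import ExportArguments, TrainArguments, export_main, sft_main

# Step 1: encode the dataset once and save it to disk (writes <output_dir>/train and, if present, <output_dir>/val).
export_main(
    ExportArguments(
        model='Qwen/Qwen2.5-7B-Instruct',
        dataset='swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT',
        to_cached_dataset=True,  # new flag introduced in this patch
        output_dir='output/cached_dataset',  # hypothetical path; by default it is derived automatically
        dataset_num_proc=4,
    ))

# Step 2: train directly from the cached dataset; cached_dataset is the new BaseArguments field
# (a str is normalized to a list in __post_init__), and packing works on top of it.
sft_main(
    TrainArguments(
        model='Qwen/Qwen2.5-7B-Instruct',
        cached_dataset=['output/cached_dataset'],
        packing=True,
    ))

Note that cached_dataset requires streaming and lazy_tokenize to stay disabled, matching the assertion in SwiftSft._get_cached_dataset.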