support export cached_dataset #4992


Open · wants to merge 14 commits into base: main
20 changes: 19 additions & 1 deletion swift/llm/argument/base_args/base_args.py
@@ -63,6 +63,8 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat
seed (int): Random seed for reproducibility. Default is 42.
model_kwargs (Optional[str]): Additional keyword arguments for the model. Default is None.
load_data_args (bool): Flag to determine if dataset configuration should be loaded. Default is False.
packing (bool): Flag to enable packing of datasets. Default is False.
lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None.
use_hf (bool): Flag to determine if Hugging Face should be used. Default is False.
hub_token (Optional[str]): SDK token for authentication. Default is None.
custom_register_path (List[str]): Path to custom .py file for dataset registration. Default is None.
@@ -80,6 +82,8 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat
load_data_args: bool = False
# dataset
packing: bool = False
lazy_tokenize: Optional[bool] = None
cached_dataset: List[str] = field(default_factory=list)
custom_register_path: List[str] = field(default_factory=list) # .py
# hub
use_hf: bool = False
@@ -97,6 +101,18 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat
def _prepare_training_args(self, training_args: Dict[str, Any]) -> None:
pass

def _init_lazy_tokenize(self):
if self.streaming and self.lazy_tokenize:
self.lazy_tokenize = False
logger.warning('Streaming and lazy_tokenize are incompatible. '
f'Setting args.lazy_tokenize: {self.lazy_tokenize}.')
if self.lazy_tokenize is None:
if self.model_meta.is_multimodal and not self.streaming and not self.packing:
self.lazy_tokenize = True
else:
self.lazy_tokenize = False
logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')

def _init_custom_register(self) -> None:
"""Register custom .py file to datasets"""
if isinstance(self.custom_register_path, str):
@@ -154,7 +170,9 @@ def __post_init__(self):
QuantizeArguments.__post_init__(self)
TemplateArguments.__post_init__(self)
DataArguments.__post_init__(self)

if isinstance(self.cached_dataset, str):
self.cached_dataset = [self.cached_dataset]
self._init_lazy_tokenize()
self.hub = get_hub(self.use_hf)
if self.hub.try_login(self.hub_token):
logger.info('hub login successful!')
5 changes: 5 additions & 0 deletions swift/llm/argument/export_args.py
@@ -40,6 +40,9 @@ class ExportArguments(MergeArguments, BaseArguments):
quant_batch_size: int = 1
group_size: int = 128

# cached_dataset
to_cached_dataset: bool = False

# ollama
to_ollama: bool = False

@@ -79,6 +82,8 @@ def _init_output_dir(self):
suffix = 'mcore'
elif self.to_hf:
suffix = 'hf'
elif self.to_cached_dataset:
suffix = 'cached_dataset'
else:
return

13 changes: 0 additions & 13 deletions swift/llm/argument/train_args.py
@@ -109,16 +109,13 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
Args:
add_version (bool): Flag to add version information to output_dir. Default is True.
loss_type (Optional[str]): Type of loss function to use. Default is None.
packing (bool): Flag to enable packing of datasets. Default is False.
lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None.
max_new_tokens (int): Maximum number of new tokens to generate. Default is 64.
temperature (float): Temperature for sampling. Default is 0.
optimizer (Optional[str]): Optimizer type to use, define it in the plugin package. Default is None.
metric (Optional[str]): Metric to use for evaluation, define it in the plugin package. Default is None.
"""
add_version: bool = True
create_checkpoint_symlink: bool = False
lazy_tokenize: Optional[bool] = None

# plugin
loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
@@ -135,15 +132,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
# auto_tp
deepspeed_autotp_size: Optional[int] = None

def _init_lazy_tokenize(self):
if self.streaming and self.lazy_tokenize:
self.lazy_tokenize = False
logger.warning('Streaming and lazy_tokenize are incompatible. '
f'Setting args.lazy_tokenize: {self.lazy_tokenize}.')
if self.lazy_tokenize is None:
self.lazy_tokenize = self.model_meta.is_multimodal and not self.streaming
logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')

def __post_init__(self) -> None:
if self.padding_free or self.packing:
if self.packing:
@@ -179,7 +167,6 @@ def __post_init__(self) -> None:

self._init_deepspeed()
self._init_device()
self._init_lazy_tokenize()

if getattr(self, 'accelerator_config', None) is None:
self.accelerator_config = {'dispatch_batches': False}
28 changes: 11 additions & 17 deletions swift/llm/dataset/utils.py
@@ -146,23 +146,18 @@ def __init__(
self.strict = strict
self.load_from_cache_file = load_from_cache_file
self.workers = []
preprocessor = EncodePreprocessor(template=template)
self.dataset = preprocessor(
dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict)
if template.model_meta.is_multimodal:
self.dataset = LazyLLMDataset(self.dataset, encode_func=template.encode)
self.packed_idx = self.create_packed_idx() if is_master() else None
self.packed_idx, self.packed_length = self.create_packed_idx() if is_master() else (None, None)
if dist.is_initialized() and is_dist():
obj_list = [self.packed_idx]
obj_list = [(self.packed_idx, self.packed_length)]
dist.broadcast_object_list(obj_list)
self.packed_idx = obj_list[0]
self.packed_idx, self.packed_length = obj_list[0]

def create_packed_idx(self):
lengths = self.dataset['length']
data = [(i, length) for i, length in enumerate(lengths)]
i = 0
PACKING_BATCH_SIZE = 1000
input_data, res = [], []
input_data, packed_idx, packed_length = [], [], []
with tqdm(total=len(data), dynamic_ncols=True, desc='Packing: ') as prog_bar:
while True:
new_data = data[i:i + PACKING_BATCH_SIZE]
@@ -173,14 +168,13 @@ def create_packed_idx(self):
i += PACKING_BATCH_SIZE
is_finished = i >= len(data)
sequences, input_data = calculate_matched_group(self.template, input_data, is_finished=is_finished)
res += sequences
return res
packed_idx += [[x[0] for x in seq] for seq in sequences]
packed_length += [sum(x[1] for x in seq) for seq in sequences]
return packed_idx, packed_length

def __getitem__(self, index):
sequence = self.packed_idx[index]
row = []
for i, length in sequence:
row.append((self.dataset[i], length))
row = [self.dataset[i] for i in sequence]
return self.template.packing_row(row)

def __len__(self):
@@ -221,7 +215,7 @@ def _processor(self):
i, data = self._in_queue.get()
encoded_data = {}
try:
encoded_data = self.template.encode(data)
encoded_data = self.template.encode(data, return_length=True)
except Exception as e:
if self.strict and not isinstance(e, MaxLengthError):
raise
@@ -271,7 +265,7 @@ def __iter__(self):
sequences, data = calculate_matched_group(self.template, data, is_finished=finished)
res = []
for row in sequences:
packed = self.template.packing_row(row)
packed = self.template.packing_row([r[0] for r in row])
res.append(packed)
yield from res
if finished:
@@ -286,7 +280,7 @@ def __init__(self, template: 'Template'):
self.is_multimodal = template.model_meta.is_multimodal

def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
encoded = self.template.encode(row)
encoded = self.template.encode(row, return_length=True)
if self.is_multimodal:
row['length'] = encoded['length']
encoded = row
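
To make the new return shape of create_packed_idx concrete, here is a small illustrative sketch; the values are made up, and only the structure mirrors the code above: each entry of packed_idx lists the dataset row indices grouped into one packed sequence, and packed_length holds the matching total token count.

# Toy illustration of the (packed_idx, packed_length) pair returned by create_packed_idx.
lengths = [120, 300, 80, 500, 100]    # per-row token lengths, as in self.dataset['length']
packed_idx = [[0, 2, 4], [1], [3]]    # hypothetical grouping produced by the packing loop
packed_length = [sum(lengths[i] for i in group) for group in packed_idx]
print(packed_length)                  # [300, 300, 500]
# __getitem__(0) would then combine rows 0, 2 and 4 into a single packed training sample.
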
1 change: 1 addition & 0 deletions swift/llm/export/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .cached_dataset import export_cached_dataset
from .export import SwiftExport, export_main
from .merge_lora import merge_lora
from .ollama import export_to_ollama
35 changes: 35 additions & 0 deletions swift/llm/export/cached_dataset.py
@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List, Union

from swift.llm import ExportArguments
from swift.llm.train import SwiftSft
from swift.utils import get_logger

logger = get_logger()


class ExportCachedDataset(SwiftSft):
args_class = ExportArguments
args: args_class

def __init__(self, args: Union[List[str], ExportArguments, None] = None) -> None:
super(SwiftSft, self).__init__(args)
Reviewer comment (severity: medium):

Using super(SwiftSft, self) bypasses SwiftSft's __init__ and invokes SwiftPipeline.__init__ directly. While functional, this is not idiomatic in Python 3 and can be confusing. For better readability, call the grandparent's __init__ explicitly; this makes the intent of skipping the direct parent's initializer clear. You'll need to add from swift.llm import SwiftPipeline to your imports.

Suggested change
super(SwiftSft, self).__init__(args)
from swift.llm import SwiftPipeline
SwiftPipeline.__init__(self, args)
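
For illustration only, a minimal self-contained sketch with hypothetical class names (not from this repository) showing that both spellings run only the grandparent's initializer:

class Pipeline:
    def __init__(self, args):
        self.args = args

class Sft(Pipeline):
    def __init__(self, args):
        super().__init__(args)
        self.model = 'loaded'  # heavy setup the subclass wants to skip

class ExportCached(Sft):
    def __init__(self, args):
        # Both of the following skip Sft.__init__ and invoke Pipeline.__init__ directly;
        # the explicit grandparent call makes that intent obvious at a glance.
        # super(Sft, self).__init__(args)
        Pipeline.__init__(self, args)

obj = ExportCached('args')
print(hasattr(obj, 'model'))  # False: Sft.__init__ never ran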

self.train_msg = {} # dummy
self.processor = None
self._prepare_template()
self._prepare_model_tokenizer(load_model=self.template.use_model)
self.template.init_processor(self.processor)

def main(self):
train_dataset, val_dataset = self._get_dataset()
train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
self._show_dataset(train_dataset, val_dataset)
train_dataset.save_to_disk(os.path.join(self.args.output_dir, 'train'))
if val_dataset is not None:
val_dataset.save_to_disk(os.path.join(self.args.output_dir, 'val'))
logger.info(f'Dataset saved to `{self.args.output_dir}`')


def export_cached_dataset(args: Union[List[str], ExportArguments, None] = None):
return ExportCachedDataset(args).main()
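
For orientation, a hedged sketch of how this new entry point might be invoked programmatically; the model and dataset values are placeholders rather than anything from the PR, and only the fields shown in the diffs above (to_cached_dataset, output_dir, cached_dataset) are assumed:

from swift.llm import ExportArguments
from swift.llm.export import export_cached_dataset

# Hypothetical invocation: pre-tokenize the dataset once and save it to disk.
args = ExportArguments(
    model='Qwen/Qwen2.5-7B-Instruct',  # placeholder model id
    dataset=['my_dataset.jsonl'],      # placeholder dataset spec
    to_cached_dataset=True,            # new flag added in this PR
    output_dir='./cached_ds',
)
export_cached_dataset(args)
# The saved './cached_ds/train' (and './cached_ds/val', if any) directories could then be
# passed back to training through the new cached_dataset argument on BaseArguments.
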
3 changes: 3 additions & 0 deletions swift/llm/export/export.py
@@ -4,6 +4,7 @@
from swift.llm import ExportArguments, SwiftPipeline
from swift.tuners import swift_to_peft_format
from swift.utils import get_logger
from .cached_dataset import export_cached_dataset
from .merge_lora import merge_lora
from .ollama import export_to_ollama
from .quant import quantize_model
@@ -29,6 +30,8 @@ def run(self):
quantize_model(args)
elif args.to_ollama:
export_to_ollama(args)
elif args.to_cached_dataset:
export_cached_dataset(args)
elif args.to_mcore:
from swift.megatron import convert_hf2mcore
convert_hf2mcore(args)
3 changes: 2 additions & 1 deletion swift/llm/infer/infer.py
@@ -250,7 +250,8 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
prog_bar.close()
metrics = self.infer_kwargs.pop('metrics')
if result_list:
print(f'[rank{args.rank}] {metrics[0].compute()}')
metric = metrics[0].compute()
print(f'[rank{args.rank}] {metric}' if args.rank >= 0 else str(metric))
if args.metric is not None:
self._calc_metric()
return result_list
23 changes: 13 additions & 10 deletions swift/llm/template/base.py
@@ -86,7 +86,7 @@ def __init__(
from .template_meta import TemplateMeta
from swift.plugin import agent_templates, loss_scale_map
self._processor_inited = False
self._version = 'v1' # Avoid compatibility issues caused by load_from_cache_file caching.
self._version = 'v2' # Avoid compatibility issues caused by load_from_cache_file caching.
self.max_length = max_length
self.model = None

@@ -488,7 +488,8 @@ def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
@torch.inference_mode()
def encode(self,
inputs: Union[TemplateInputs, Dict[str, Any], InferRequest],
return_template_inputs: bool = False) -> Dict[str, Any]:
return_template_inputs: bool = False,
return_length: bool = False) -> Dict[str, Any]:
"""The entrance method of Template!

Returns:
@@ -537,7 +538,7 @@ def encode(self,
lengths.append(value)
elif isinstance(value, (tuple, list)):
lengths += value
if self.is_training:
if return_length:
encoded['length'] = max(lengths)
else:
encoded.pop('length', None)
@@ -547,22 +548,24 @@
encoded['_extra_kwargs'] = extra_kwargs
return encoded

def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]:
packed = {}
keys = set()
length = []
for r in row:
keys.update(r[0].keys())
keys.update(r.keys())
length.append(r['length'])
for key in keys:
if key in {'input_ids', 'labels', 'loss_scale'}:
packed[key] = sum((x[0][key] for x in row), start=[])
packed[key] = sum((x[key] for x in row), start=[])
elif key == 'length':
packed[key] = sum((x[0][key] for x in row))
packed[key] = sum((x[key] for x in row))
elif key == 'channel':
packed[key] = [x[0][key] for x in row]
packed[key] = [x[key] for x in row]
if 'position_ids' not in packed:
packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[])
packed['position_ids'] = sum((list(range(x)) for x in length), start=[])

packed.update(self._data_collator_mm_data([r[0] for r in row]))
packed.update(self._data_collator_mm_data(row))
return packed

def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
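
As a quick illustration of the updated packing_row contract (rows are now plain encoded dicts that carry their own 'length'), a hedged sketch with made-up values:

# Illustrative input for packing_row after this change (values are made up).
rows = [
    {'input_ids': [1, 2, 3], 'labels': [-100, 2, 3], 'length': 3},
    {'input_ids': [4, 5], 'labels': [4, 5], 'length': 2},
]
# Expected packed result:
#   input_ids    -> [1, 2, 3, 4, 5]
#   labels       -> [-100, 2, 3, 4, 5]
#   length       -> 5
#   position_ids -> [0, 1, 2, 0, 1]   (restarts at each original row boundary)
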
4 changes: 2 additions & 2 deletions swift/llm/template/template/qwen.py
@@ -381,10 +381,10 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
res[f'{media_type}_grid_thw'] = grid_thw
return res

def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]:
position_ids = []
for r in row:
r = r[0].copy()
r = r.copy()
r['input_ids'] = torch.tensor(r['input_ids'])[None]
position_ids.append(self._get_position_ids(r))
packed = super().packing_row(row)