
Commit f827a03

Author: maxtext authors
Merge pull request #1801 from bzantium:feature/#1610
PiperOrigin-RevId: 772620051
2 parents c551e72 + 2fa06eb, commit f827a03

File tree: 5 files changed (+75 -7 lines)

MaxText/configs/base.yml

Lines changed: 5 additions & 0 deletions
@@ -438,6 +438,11 @@ hf_eval_split: ''
 hf_eval_files: ''
 hf_access_token: ''
 # for Grain input pipeline (dataset_type=grain)
+# Path to grain data files. Can be a single pattern or multiple patterns with weights.
+# For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights.
+# Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7"
+# Note: When using multiple files (separated by ';'), only ArrayRecord format is supported.
+# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline
 grain_train_files: ''
 grain_eval_files: ''
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
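
As an aside, the weighted-pattern syntax documented in these new comments can be decomposed with plain string splitting. A minimal standalone sketch (not MaxText code; it reuses the example string from the comment above and mirrors the `split(";")` / `split(":")` convention):

```python
# Standalone sketch: how the documented multi-pattern string decomposes.
# The pattern string is the example value from the comment above.
example = "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7"

for entry in example.split(";"):        # ';' separates data sources
  pattern, weight = entry.split(":")    # ':' separates a pattern from its weight
  print(pattern, float(weight))
# path/to/data1.array_record* 0.3
# path/to/data2.array_record* 0.7
```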

MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 20 additions & 4 deletions
@@ -33,6 +33,13 @@
 from MaxText import tokenizer
 
 
+def find_data_files(data_file_pattern):
+  data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
+  assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}."
+  max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
+  return data_files
+
+
 def get_datasets(
     data_file_pattern,
     data_file_type,
@@ -44,17 +51,26 @@ def get_datasets(
     grain_worker_count,
 ):
   """Load dataset from array_record files for using with grain"""
-  data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
-  assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}."
-  max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
   if data_file_type == "arrayrecord":
-    dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
+    if ";" in data_file_pattern:
+      data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
+      assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
+      weights = [float(weight) for weight in weights]
+      weights = [round(weight / sum(weights), 4) for weight in weights]
+      dataset_list = [
+          grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns
+      ]
+      dataset = grain.MapDataset.mix(dataset_list, weights)
+    else:
+      data_files = find_data_files(data_file_pattern)
+      dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
     if shuffle:
       dataset = dataset.shuffle(seed=shuffle_seed)
     dataset = dataset.repeat(num_epoch)
     dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
     dataset = dataset.to_iter_dataset()
   elif data_file_type == "parquet":
+    data_files = find_data_files(data_file_pattern)
     dataset = grain.MapDataset.source(data_files)
     if shuffle:
       dataset = dataset.shuffle(seed=shuffle_seed)
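
The new `arrayrecord` branch blends the per-pattern sources with `grain.MapDataset.mix`. Below is a minimal sketch of what that blending looks like, using small in-memory lists in place of ArrayRecord files; it assumes the Grain package is installed and imported under the same alias as above, and the toy data is purely illustrative:

```python
import grain.python as grain  # Grain's Python API, as used in the diff above

# Two toy "sources" standing in for ArrayRecord-backed datasets.
source_a = grain.MapDataset.source(["a0", "a1", "a2", "a3", "a4"])
source_b = grain.MapDataset.source(["b0", "b1", "b2", "b3", "b4"])

# Mirrors grain.MapDataset.mix(dataset_list, weights): roughly 30% of
# elements come from source_a and 70% from source_b.
mixed = grain.MapDataset.mix([source_a, source_b], [0.3, 0.7])

print([mixed[i] for i in range(5)])  # elements interleaved according to the weights
```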

MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 1 addition & 1 deletion
@@ -332,7 +332,7 @@ def map(self, element):
 
 @dataclasses.dataclass
 class Rekey(grain.MapTransform):
-  """Rname keys according to a mappign dict"""
+  """Rename keys according to a mapping dict"""
 
   def __init__(self, mapping_dict, keep_old_keys=False):
     self.mapping_dict = mapping_dict
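
This hunk only fixes the docstring typo. For readers unfamiliar with the transform, the renaming idea is roughly the following; this is a hypothetical standalone sketch, not the MaxText `Rekey.map` implementation (which is not shown in this diff), and the old-key-to-new-key mapping direction is an assumption:

```python
def rekey(element, mapping_dict, keep_old_keys=False):
  """Hypothetical sketch: rename keys of an element dict per mapping_dict."""
  out = dict(element) if keep_old_keys else {}
  for old_key, new_key in mapping_dict.items():
    if old_key in element:
      out[new_key] = element[old_key]
  return out

print(rekey({"text": "hello", "label": 1}, {"text": "inputs"}))  # {'inputs': 'hello'}
```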

MaxText/tests/grain_data_processing_test.py

Lines changed: 36 additions & 0 deletions
@@ -105,6 +105,42 @@ def get_first_batch(iterator):
     self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all())
 
 
+class GrainArrayRecordProcessingWithMultiSourceBlendingTest(GrainArrayRecordProcessingTest):
+
+  def setUp(self):
+    super().setUp()
+    temp_dir = tempfile.gettempdir()
+    # We use the same dataset for testing, but you can use different datasets by changing the file patterns.
+    grain_train_files = [
+        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3",
+        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7",
+    ]
+    grain_train_files = ";".join(grain_train_files)
+    self.config = pyconfig.initialize(
+        [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")],
+        per_device_batch_size=1,
+        run_name="test",
+        mesh_axes=["data"],
+        logical_axis_rules=[["batch", "data"]],
+        data_sharding=["data"],
+        base_output_directory="gs://max-experiments/",
+        dataset_type="grain",
+        grain_train_files=grain_train_files,
+        tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer"),
+        enable_checkpointing=False,
+    )
+    self.mesh_shape_1d = (len(jax.devices()),)
+    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
+    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
+        self.config.data_sharding,
+        self.config.global_batch_size_to_load,
+        self.config.global_batch_size_to_train_on,
+        self.config.max_target_length,
+        self.mesh,
+    )
+    self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
+
+
 class GrainParquetProcessingTest(unittest.TestCase):
 
   @classmethod

getting_started/Data_Input_Pipeline.md

Lines changed: 13 additions & 2 deletions
@@ -102,7 +102,18 @@ bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$BUCKET_NAME MOUNT_PATH=$MOUNT_PATH [FI
 ```
 3. Set `dataset_type=grain` and set `grain_train_files` to match the ArrayRecord files via a local path since the bucket has been mounted.
 4. Tune `grain_worker_count` for performance. This parameter controls the number of child process used by Grain (more details in [behind_the_scene](https://github.com/google/grain/blob/main/docs/behind_the_scenes.md), [code](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, please check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/setup_gcsfuse.sh) to avoid gcsfuse throttling.
-5. Example command:
+
+5. For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as separator and colon (:) for weights. The weights will be automatically normalized to sum to 1.0. For example:
+```
+# Blend two data sources with 30% from first source and 70% from second source
+grain_train_files=/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7
+
+# Blend three data sources with equal weights (will be normalized to 0.33 each)
+grain_train_files=/tmp/gcsfuse/dataset1.array_record*:1;/tmp/gcsfuse/dataset2.array_record*:1;/tmp/gcsfuse/dataset3.array_record*:1
+```
+Note: When using multiple data sources, only ArrayRecord format is supported.
+
+6. Example command:
 ```
 bash setup_gcsfuse.sh \
 DATASET_GCS_BUCKET=maxtext-dataset \
@@ -114,7 +125,7 @@ grain_file_type=arrayrecord \
 grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \
 grain_worker_count=2
 ```
-6. Using validation set for eval
+7. Using validation set for eval
 When setting eval_interval > 0, eval will be run with a specified eval dataset. Example config:
 ```
 eval_interval: 10000
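
The normalization the doc mentions comes from the step added in `_grain_data_processing.py`, which divides each weight by the sum of all weights and rounds to four decimal places, so three equal weights of 1 become 0.3333 each rather than exactly 0.33. A quick plain-Python check of that arithmetic, reusing the weights from the examples above:

```python
def normalize(weights, ndigits=4):
  # Mirrors the normalization in get_datasets: weight / sum(weights), rounded to 4 decimals.
  total = sum(weights)
  return [round(w / total, ndigits) for w in weights]

print(normalize([0.3, 0.7]))  # [0.3, 0.7]  (already sums to 1.0)
print(normalize([1, 1, 1]))   # [0.3333, 0.3333, 0.3333]  (the "~0.33 each" case)
```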
