Commit fa931f2 (2 parents: 4a1e40e + ac511b5)
Author: maxtext authors

Merge pull request #1803 from AI-Hypercomputer:input_batch_none

PiperOrigin-RevId: 769297433

File tree: 4 files changed, +69 -48 lines changed

MaxText/elastic_train.py
MaxText/experimental/rl/grpo_trainer.py
MaxText/sft_trainer.py
MaxText/train.py

MaxText/elastic_train.py
Lines changed: 15 additions & 12 deletions

@@ -232,6 +232,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
       donate_argnums=donate_argnums_train,
   )
   running_gcs_metrics = [] if config.gcs_metrics else None
+  metrics = None

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
@@ -380,7 +381,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
   ) = ret

   if checkpoint_manager is not None:
-    if (int(state.step) - 1) % config.checkpoint_period != 0:
+    if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):
       try:
         state_to_save = state
         if save_checkpoint(
@@ -400,17 +401,19 @@ def train_loop(config, elastic_manager, recorder, state=None):
   metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
   max_utils.close_summary_writer(writer)

-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-    # pytype: disable=attribute-error
-    compiled = p_train_step.lower(state, example_batch, nextrng).compile()
-    compiled_stats = compiled.memory_analysis()
-    if compiled_stats is not None:
-      max_logging.log(
-          f"Output size: {compiled_stats.output_size_in_bytes}, "
-          f"temp size: {compiled_stats.temp_size_in_bytes}, "
-          f"argument size: {compiled_stats.argument_size_in_bytes}, "
-          f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
-      )
+  if example_batch:
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      # pytype: disable=attribute-error
+      compiled = p_train_step.lower(state, example_batch, nextrng).compile()
+      compiled_stats = compiled.memory_analysis()
+      if compiled_stats is not None:
+        max_logging.log(
+            f"Output size: {compiled_stats.output_size_in_bytes}, "
+            f"temp size: {compiled_stats.temp_size_in_bytes}, "
+            f"argument size: {compiled_stats.argument_size_in_bytes}, "
+            f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
+        )
+
   return state
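Note: the new `if example_batch:` guard, repeated in the three files below, protects the post-run memory analysis. If batch loading never succeeds, `example_batch` can still be `None` when the loop exits (see the grpo_trainer.py change below, which breaks out of the loop on load failure), so `p_train_step.lower(state, None, nextrng)` would raise instead of logging stats. A minimal runnable sketch of the pattern, assuming a jitted step function; `toy_step`, the scalar state, and the batch contents are illustrative stand-ins, not MaxText code:

import jax
import jax.numpy as jnp


def toy_step(state, batch, rng):
  # Stand-in for MaxText's train_step: consume the batch, update the state.
  del rng  # unused in this sketch
  return state + jnp.sum(batch["inputs"])


p_train_step = jax.jit(toy_step)
example_batch = None  # stays None if the data iterator never yielded a batch

if example_batch:  # guard: skip the memory analysis when there is no batch
  compiled = p_train_step.lower(
      jnp.float32(0.0), example_batch, jax.random.PRNGKey(0)
  ).compile()
  stats = compiled.memory_analysis()
  if stats is not None:
    print(f"temp size: {stats.temp_size_in_bytes} bytes")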

MaxText/experimental/rl/grpo_trainer.py
Lines changed: 27 additions & 13 deletions

@@ -775,6 +775,7 @@ def train_loop(config, config_inference, recorder, state=None):
   )

   running_gcs_metrics = [] if config.gcs_metrics else None
+  metrics = None

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
@@ -804,8 +805,12 @@ def train_loop(config, config_inference, recorder, state=None):

     with jax.profiler.StepTraceAnnotation("train", step_num=step):
       with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
-        example_batch = load_next_batch(data_iterator, example_batch, config)
-        example_batch = jax.lax.with_sharding_constraint(example_batch, input_data_shardings)
+        try:
+          example_batch = load_next_batch(data_iterator, example_batch, config)
+          example_batch = jax.lax.with_sharding_constraint(example_batch, input_data_shardings)
+        except Exception as e:  # pylint: disable=broad-except
+          max_logging.log(f"load_next_batch failed, you may have run out of data. Error message: {e}")
+          break

       check_example_batch(config, example_batch=example_batch)
       # pylint: disable=not-callable
@@ -896,20 +901,29 @@ def train_loop(config, config_inference, recorder, state=None):
   max_utils.print_mem_stats("After params initialized")

   if checkpoint_manager is not None:
+    if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):
+      try:
+        if save_checkpoint(
+            checkpoint_manager, int(state.step) - 1, state, config.dataset_type, data_iterator, config, force=True
+        ):
+          checkpointing.print_save_message(int(state.step) - 1, config.async_checkpointing)
+      except Exception:  # pylint: disable=broad-except
+        max_logging.log(f"Checkpoint already saved for step {int(state.step)-1}.")
+
     checkpoint_manager.wait_until_finished()
   metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
   max_utils.close_summary_writer(writer)
-
-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-    compiled = p_train_step.lower(state, example_batch, rng).compile()
-    compiled_stats = compiled.memory_analysis()
-    if compiled_stats is not None:
-      max_logging.log(
-          f"Output size: {compiled_stats.output_size_in_bytes}, "
-          f"temp size: {compiled_stats.temp_size_in_bytes}, "
-          f"argument size: {compiled_stats.argument_size_in_bytes}, "
-          f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
-      )
+  if example_batch:
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      compiled = p_train_step.lower(state, example_batch, rng).compile()
+      compiled_stats = compiled.memory_analysis()
+      if compiled_stats is not None:
+        max_logging.log(
+            f"Output size: {compiled_stats.output_size_in_bytes}, "
+            f"temp size: {compiled_stats.temp_size_in_bytes}, "
+            f"argument size: {compiled_stats.argument_size_in_bytes}, "
+            f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
+        )
   return state
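Note: the try/except around `load_next_batch` turns data exhaustion into a clean break out of the training loop rather than an unhandled crash, and pairs with the `if example_batch:` guard above. The control flow reduced to a stand-alone sketch; the iterator, batches, and logging call are stand-ins for `data_iterator`, the real input batches, and `max_logging.log`:

import logging

# Stand-in for data_iterator: yields two batches, then is exhausted.
fake_batches = iter([{"inputs": [1, 2]}, {"inputs": [3, 4]}])

example_batch = None
for step in range(10):
  try:
    # Stand-in for load_next_batch(...); the real code catches broadly
    # because iterator backends can raise more than StopIteration.
    example_batch = next(fake_batches)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("load_next_batch failed, you may have run out of data. Error message: %s", e)
    break
  print(f"step {step}: trained on batch {example_batch['inputs']}")

After the break, `example_batch` still holds the last successfully loaded batch (or `None` if the very first load failed), which is exactly what the `if example_batch:` guard checks before the final memory analysis.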

MaxText/sft_trainer.py
Lines changed: 13 additions & 11 deletions

@@ -133,6 +133,7 @@ def train_loop(config, recorder, state=None):
   )

   running_gcs_metrics = [] if config.gcs_metrics else None
+  metrics = None

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
@@ -249,7 +250,7 @@ def train_loop(config, recorder, state=None):
   max_utils.print_mem_stats("After params initialized")

   if checkpoint_manager is not None:
-    if (int(state.step) - 1) % config.checkpoint_period != 0:
+    if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):
       try:
         if save_checkpoint(
             checkpoint_manager, int(state.step) - 1, state, config.dataset_type, data_iterator, config, force=True
@@ -262,16 +263,17 @@ def train_loop(config, recorder, state=None):
   metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
   max_utils.close_summary_writer(writer)

-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-    compiled = p_train_step.lower(state, example_batch, nextrng).compile()
-    compiled_stats = compiled.memory_analysis()
-    if compiled_stats is not None:
-      max_logging.log(
-          f"Output size: {compiled_stats.output_size_in_bytes}, "
-          f"temp size: {compiled_stats.temp_size_in_bytes}, "
-          f"argument size: {compiled_stats.argument_size_in_bytes}, "
-          f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
-      )
+  if example_batch:
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      compiled = p_train_step.lower(state, example_batch, nextrng).compile()
+      compiled_stats = compiled.memory_analysis()
+      if compiled_stats is not None:
+        max_logging.log(
+            f"Output size: {compiled_stats.output_size_in_bytes}, "
+            f"temp size: {compiled_stats.temp_size_in_bytes}, "
+            f"argument size: {compiled_stats.argument_size_in_bytes}, "
+            f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
+        )
   return state
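Note: the added `int(state.step) != 0` clause guards the step-zero edge case in the final forced checkpoint. Python's modulo gives `(0 - 1) % checkpoint_period == checkpoint_period - 1`, so the old condition was true for a run that never trained and would attempt to force-save a checkpoint for step -1. A worked check of both conditions; `checkpoint_period = 5` is an arbitrary example value and `step` plays the role of `int(state.step)`:

checkpoint_period = 5  # arbitrary example value

for step in (0, 1, 13):
  old = (step - 1) % checkpoint_period != 0
  new = ((step - 1) % checkpoint_period != 0) and (step != 0)
  print(f"step={step:>2}  old={old}  new={new}")

# step= 0: (-1) % 5 == 4, so old is True and would force-save a checkpoint
#          for "step -1"; new is False and skips it.
# step= 1: (1 - 1) % 5 == 0, so both are False; step 0 was already saved by
#          the periodic checkpointing.
# step=13: (13 - 1) % 5 == 2, so both are True; step 12 still needs the
#          forced final save.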

MaxText/train.py
Lines changed: 14 additions & 12 deletions

@@ -814,6 +814,7 @@ def train_loop(config, recorder, state=None):
   p_eval_step = None

   running_gcs_metrics = [] if config.gcs_metrics else None
+  metrics = None

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
@@ -936,7 +937,7 @@ def train_loop(config, recorder, state=None):
   max_utils.print_mem_stats("After params initialized")

   if checkpoint_manager is not None:
-    if (int(state.step) - 1) % config.checkpoint_period != 0:
+    if ((int(state.step) - 1) % config.checkpoint_period != 0) and (int(state.step) != 0):
       try:
         state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
         if save_checkpoint(
@@ -956,17 +957,18 @@ def train_loop(config, recorder, state=None):
   metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
   max_utils.close_summary_writer(writer)

-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-    # pytype: disable=attribute-error
-    compiled = p_train_step.lower(state, example_batch, nextrng).compile()
-    compiled_stats = compiled.memory_analysis()
-    if compiled_stats is not None:
-      max_logging.log(
-          f"Output size: {compiled_stats.output_size_in_bytes}, "
-          f"temp size: {compiled_stats.temp_size_in_bytes}, "
-          f"argument size: {compiled_stats.argument_size_in_bytes}, "
-          f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
-      )
+  if example_batch:
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      # pytype: disable=attribute-error
+      compiled = p_train_step.lower(state, example_batch, nextrng).compile()
+      compiled_stats = compiled.memory_analysis()
+      if compiled_stats is not None:
+        max_logging.log(
+            f"Output size: {compiled_stats.output_size_in_bytes}, "
+            f"temp size: {compiled_stats.temp_size_in_bytes}, "
+            f"argument size: {compiled_stats.argument_size_in_bytes}, "
+            f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
+        )
   return state
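Note: the `metrics = None` initialization added before each training loop appears to cover the same empty-run corner: `metrics` is otherwise first assigned inside the loop, so a run whose loop body never executes (for example, one restored with `start_step` already at `config.steps`) would reach the final `metric_logger.write_metrics(...)` call with an undefined name. A minimal sketch of the failure mode and the fix; the `write_metrics` stand-in is illustrative, not the MaxText metric logger:

def write_metrics(metrics, step):
  # Stand-in for metric_logger.write_metrics: tolerate an empty run.
  if metrics is None:
    print(f"step {step}: no metrics to write")
  else:
    print(f"step {step}: {metrics}")


start_step, steps = 5, 5  # restored at the final step: loop body never runs

metrics = None  # without this line, the call below raises NameError
for step in range(start_step, steps):
  metrics = {"loss": 0.0}  # previously the only assignment to `metrics`

write_metrics(metrics, steps - 1)  # final step metrics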
