
Commit fc920f4

Refactor: Recording and logging training and evaluation metrics in all trainers
1 parent fa931f2 commit fc920f4

8 files changed: +177, -273 lines


MaxText/elastic_train.py

Lines changed: 10 additions & 43 deletions
@@ -38,12 +38,10 @@
 for more details about the elastic manager.
 """
 from collections.abc import Sequence
-import datetime
 import logging
 import os
 import sys
 import time
-import queue

 from absl import app

@@ -68,13 +66,11 @@
 from MaxText import max_logging
 from MaxText import profiler
 from MaxText import pyconfig
-from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
 from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import check_example_batch
 from MaxText.train import get_first_step
 from MaxText.train import load_next_batch
-from MaxText.train import record_scalar_metrics
 from MaxText.train import save_checkpoint
 from MaxText.train import setup_mesh_and_model
 from MaxText.train import setup_train_loop
@@ -120,7 +116,7 @@ def elastic_handler(
     checkpoint_manager.close()

   with jax.default_device(elastic_manager.default_device):
-    init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(
+    init_rng, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(
         config, elastic_manager.good_devices
     )
     with mesh:
@@ -171,7 +167,7 @@ def elastic_handler(
     )

     example_batch = None
-    metric_logger = MetricLogger(writer, config)
+    metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

     jax.block_until_ready(state)

@@ -186,15 +182,13 @@ def elastic_handler(
       example_batch,
       learning_rate_schedule,
       metric_logger,
-      writer,
   )


 def train_loop(config, elastic_manager, recorder, state=None):
   """Main Training loop."""
   (
       init_rng,
-      writer,
       checkpoint_manager,
       state_mesh_shardings,
       model,
@@ -214,25 +208,13 @@ def train_loop(config, elastic_manager, recorder, state=None):
       donate_argnums_train,
   ) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_shardings, model, config)

-  num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
-  max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
-  per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
-  per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)
-
-  # Write train config params, num model params, and XLA flags to tensorboard
-  max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
-  max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
-  maxtext_utils.add_config_to_summary_writer(config, writer)
-
   p_train_step = jax.jit(
       functional_train,
       in_shardings=in_shard_train,
       out_shardings=out_shard_train,
       static_argnums=static_argnums_train,
       donate_argnums=donate_argnums_train,
   )
-  running_gcs_metrics = [] if config.gcs_metrics else None
-  metrics = None

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
@@ -242,18 +224,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
   last_profiling_step = prof.finished_initial_profile_step

   example_batch = None
-  last_step_completion = datetime.datetime.now()
-
-  performance_metric_queue = None
-  if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
-    gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
-    if config.report_heartbeat_metric_for_gcp_monitoring:
-      gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
-    if config.report_performance_metric_for_gcp_monitoring:
-      performance_metric_queue = queue.Queue()
-      gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)
-
-  metric_logger = MetricLogger(writer, config)
   step = start_step

   elastic_manager.maybe_snapshot(
@@ -267,6 +237,12 @@ def train_loop(config, elastic_manager, recorder, state=None):
   )

   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
+
+  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+
+  # Write train config params, num model params, and XLA flags to tensorboard
+  metric_logger.write_setup_info_to_tensorboard(state.params)
+
   # Using while loop instead of a for loop because with elasticity
   # the step is restored back to the latest snapshot when a slice is lost
   while step < config.steps:
@@ -292,11 +268,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
       with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
         state, metrics = p_train_step(state, example_batch, nextrng)

-      step_time_delta = datetime.datetime.now() - last_step_completion
-      last_step_completion = datetime.datetime.now()
-      record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
-      if performance_metric_queue:
-        performance_metric_queue.put(step_time_delta.total_seconds())
+      metric_logger.record_train_metrics(metrics, step)

       if checkpoint_manager is not None:
         state_to_save = state
@@ -308,8 +280,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
           checkpoint_manager.wait_until_finished()
           sys.exit()

-      metric_logger.write_metrics(running_gcs_metrics, metrics, step)
-
       if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
         prof.deactivate(blocking_object=state)

@@ -347,7 +317,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
         example_batch,
         learning_rate_schedule,
         metric_logger,
-        writer,
       ) = ret

       if step == start_step:
@@ -377,7 +346,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
         example_batch,
         learning_rate_schedule,
         metric_logger,
-        writer,
       ) = ret

   if checkpoint_manager is not None:
@@ -398,8 +366,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
       max_logging.log(f"Checkpoint is already saved for step {int(state.step)-1}.")

     checkpoint_manager.wait_until_finished()
-  metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
-  max_utils.close_summary_writer(writer)
+  metric_logger.flush_metrics_and_cleanup()

   if example_batch:
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
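
Taken together, these hunks collapse the per-trainer bookkeeping (summary writer, GCP workload monitor threads, record_scalar_metrics, running_gcs_metrics) into a single MetricLogger object. The sketch below is assembled only from the calls visible in this diff; run_training, data_iter, and p_train_step are illustrative placeholders rather than the trainer's actual signature, and MetricLogger itself is implemented in MaxText/metric_logger.py, whose diff is not shown in this excerpt.

```python
# Sketch of the call pattern the refactor converges on; placeholder scaffolding,
# not MaxText's actual train_loop.
from MaxText.metric_logger import MetricLogger

def run_training(config, state, learning_rate_schedule, p_train_step, data_iter, start_step):
  # One object now owns the writer, GCP workload monitoring, step timing, and GCS metrics.
  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

  # Replaces the manual add_text_to_summary_writer / add_config_to_summary_writer calls.
  metric_logger.write_setup_info_to_tensorboard(state.params)

  for step in range(start_step, config.steps):
    example_batch = next(data_iter)
    state, metrics = p_train_step(state, example_batch, step)
    # Replaces record_scalar_metrics + performance_metric_queue + the per-step write_metrics call.
    metric_logger.record_train_metrics(metrics, step)

  # Replaces the final write_metrics call and max_utils.close_summary_writer(writer).
  metric_logger.flush_metrics_and_cleanup()
  return state
```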

MaxText/experimental/rl/grpo_trainer.py

Lines changed: 16 additions & 62 deletions
@@ -60,13 +60,13 @@
 from MaxText.common_types import Array
 from MaxText.experimental.rl import grpo_input_pipeline
 from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
+from MaxText.globals import EPS
 from MaxText.layers import models
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import (
     validate_train_config,
     get_first_step,
     load_next_batch,
-    record_scalar_metrics,
     save_checkpoint,
     check_example_batch,
     setup_mesh_and_model,
@@ -82,7 +82,6 @@
 # pylint: disable=too-many-positional-arguments

 Transformer = models.Transformer
-EPS = 1e-8


 # -----------------------------------------------------------------------------
@@ -676,7 +675,6 @@ def train_loop(config, config_inference, recorder, state=None):
   """Main Training loop."""
   (
       init_rng,
-      writer,
       checkpoint_manager,
       state_mesh_shardings,
       model,
@@ -727,15 +725,10 @@
   ) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_shardings, model, config)

   # TODO: fix tflops calculations for grpo setting
-  num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
-  max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
-  per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
-  per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)
+  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

   # Write train config params, num model params, and XLA flags to tensorboard
-  max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
-  max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
-  maxtext_utils.add_config_to_summary_writer(config, writer)
+  metric_logger.write_setup_info_to_tensorboard(state.params)

   # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
   if config.compiled_trainstep_file != "":
@@ -774,9 +767,6 @@
         donate_argnums=(0,),
     )

-  running_gcs_metrics = [] if config.gcs_metrics else None
-  metrics = None
-
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
   first_profiling_step = prof.start_initial_profile_step
@@ -785,19 +775,14 @@
   last_profiling_step = prof.finished_initial_profile_step

   example_batch = None
-  last_step_completion = datetime.datetime.now()
-
-  performance_metric_queue = None
-  if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
-    gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
-    if config.report_heartbeat_metric_for_gcp_monitoring:
-      gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
-    if config.report_performance_metric_for_gcp_monitoring:
-      performance_metric_queue = queue.Queue()
-      gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)
-
-  metric_logger = MetricLogger(writer, config)
+
   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
+
+  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+
+  # Write train config params, num model params, and XLA flags to tensorboard
+  metric_logger.write_setup_info_to_tensorboard(state.params)
+
   for step in np.arange(start_step, config.steps):
     if step == first_profiling_step or prof.should_activate_periodic_profile(step):
       optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
@@ -823,11 +808,7 @@
       with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
         state, metrics = p_train_step(state, example_batch, rng)

-    step_time_delta = datetime.datetime.now() - last_step_completion
-    last_step_completion = datetime.datetime.now()
-    record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
-    if performance_metric_queue:
-      performance_metric_queue.put(step_time_delta.total_seconds())
+    metric_logger.record_train_metrics(metrics, step)

     if checkpoint_manager is not None:
       state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
@@ -839,8 +820,6 @@
         checkpoint_manager.wait_until_finished()
         sys.exit()

-    metric_logger.write_metrics(running_gcs_metrics, metrics, step)
-
     if config.dump_hlo and step == start_step:
       jax.block_until_ready(state)  # Ensure compilation has finished.
       max_utils.upload_dump(
@@ -853,43 +832,18 @@

     if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
       assert eval_data_iterator
-      cumulative_eval_metrics = {
-          "scalar": {
-              "eval/total_loss": 0.0,
-              "eval/total_weights": 0.0,
-              "eval/avg_loss": 0.0,
-              "eval/moe_lb_loss": 0.0,
-          }
-      }
-      eval_dpo_reward_accuracy = 0.0
       eval_step_count = 0
       # pylint: disable=not-callable
       for eval_batch in eval_data_iterator:
         if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
           break
         with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
           eval_metrics = p_eval_step(state, eval_batch, rng)
-        cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
-        cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
-        cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"])
-        eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0))  # for dpo only
+        metric_logger.record_eval_metrics(metrics=eval_metrics, eval_step_count=None, step=step)
         max_logging.log(f"Completed eval step {eval_step_count}")
         eval_step_count += 1
-      eval_loss = cumulative_eval_metrics["scalar"]["eval/total_loss"] / (
-          cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS
-      )
-      cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss
-      cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = (
-          cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
-      )
-      if config.use_dpo:
-        cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count
-      metric_logger.write_metrics(running_gcs_metrics, cumulative_eval_metrics, step, is_training=False)
-      max_logging.log(
-          f"average loss after {step=}: {eval_step_count=}, {eval_loss=},"
-          f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}"
-      )
-      if eval_loss <= config.target_eval_loss:
+      metric_logger.record_eval_metrics(metrics=None, eval_step_count=eval_step_count, step=step)
+      if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
         max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
         prof.deactivate()
         break
@@ -911,8 +865,8 @@
       max_logging.log(f"Checkpoint already saved for step {int(state.step)-1}.")

     checkpoint_manager.wait_until_finished()
-  metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
-  max_utils.close_summary_writer(writer)
+  metric_logger.flush_metrics_and_cleanup()
+
   if example_batch:
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       compiled = p_train_step.lower(state, example_batch, rng).compile()
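
For reference, the inline evaluation bookkeeping deleted above summed per-batch losses and weights and then computed an EPS-guarded average. Under this commit that accumulation is assumed to move behind metric_logger.record_eval_metrics (called with metrics=eval_metrics per batch, then once with eval_step_count after the loop). The sketch below mirrors the removed code rather than the actual MaxText/metric_logger.py implementation, and omits the DPO reward-accuracy branch.

```python
# Mirror of the deleted inline eval accumulation; a rough stand-in for what
# record_eval_metrics is assumed to do internally.
from MaxText.globals import EPS

cumulative = {"scalar": {"eval/total_loss": 0.0, "eval/total_weights": 0.0, "eval/moe_lb_loss": 0.0}}

def accumulate_eval_batch(eval_metrics):
  """Per-batch pass (the diff calls record_eval_metrics with metrics=eval_metrics)."""
  scalars = eval_metrics["scalar"]
  cumulative["scalar"]["eval/total_loss"] += float(scalars["evaluation/total_loss"])
  cumulative["scalar"]["eval/total_weights"] += float(scalars["evaluation/total_weights"])
  cumulative["scalar"]["eval/moe_lb_loss"] += float(scalars["evaluation/moe_lb_loss"])

def finalize_eval(eval_step_count):
  """Post-loop pass (the diff calls record_eval_metrics with eval_step_count set)."""
  scalars = cumulative["scalar"]
  scalars["eval/avg_loss"] = scalars["eval/total_loss"] / (scalars["eval/total_weights"] + EPS)
  scalars["eval/avg_moe_lb_loss"] = scalars["eval/moe_lb_loss"] / eval_step_count
  return scalars["eval/avg_loss"]  # compared against config.target_eval_loss for early stopping
```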

MaxText/globals.py

Lines changed: 4 additions & 2 deletions
@@ -16,6 +16,8 @@

 import os.path

-PKG_DIR = os.path.dirname(os.path.abspath(__file__))
+PKG_DIR = os.path.dirname(os.path.abspath(__file__))  # MaxText directory path
+EPS = 1e-8  # Epsilon to calculate loss
+DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3  # Default checkpoint file size

-__all__ = ["PKG_DIR"]
+__all__ = ["DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE", "EPS", "PKG_DIR"]
