Commit a4e0fbc

Refactor: Recording and logging training and evaluation metrics in all trainers
1 parent fa931f2 commit a4e0fbc

File tree

8 files changed: +202, -287 lines

MaxText/elastic_train.py

Lines changed: 12 additions & 42 deletions
@@ -43,7 +43,6 @@
 import os
 import sys
 import time
-import queue

 from absl import app

@@ -68,13 +67,11 @@
 from MaxText import max_logging
 from MaxText import profiler
 from MaxText import pyconfig
-from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
 from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import check_example_batch
 from MaxText.train import get_first_step
 from MaxText.train import load_next_batch
-from MaxText.train import record_scalar_metrics
 from MaxText.train import save_checkpoint
 from MaxText.train import setup_mesh_and_model
 from MaxText.train import setup_train_loop
@@ -120,7 +117,7 @@ def elastic_handler(
     checkpoint_manager.close()

   with jax.default_device(elastic_manager.default_device):
-    init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(
+    init_rng, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(
        config, elastic_manager.good_devices
    )
   with mesh:
@@ -171,7 +168,7 @@ def elastic_handler(
    )

    example_batch = None
-    metric_logger = MetricLogger(writer, config)
+    metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

    jax.block_until_ready(state)

@@ -186,15 +183,13 @@ def elastic_handler(
      example_batch,
      learning_rate_schedule,
      metric_logger,
-      writer,
  )


 def train_loop(config, elastic_manager, recorder, state=None):
   """Main Training loop."""
   (
      init_rng,
-      writer,
      checkpoint_manager,
      state_mesh_shardings,
      model,
@@ -214,25 +209,13 @@ def train_loop(config, elastic_manager, recorder, state=None):
      donate_argnums_train,
  ) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_shardings, model, config)

-  num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
-  max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
-  per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
-  per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)
-
-  # Write train config params, num model params, and XLA flags to tensorboard
-  max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
-  max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
-  maxtext_utils.add_config_to_summary_writer(config, writer)
-
   p_train_step = jax.jit(
      functional_train,
      in_shardings=in_shard_train,
      out_shardings=out_shard_train,
      static_argnums=static_argnums_train,
      donate_argnums=donate_argnums_train,
  )
-  running_gcs_metrics = [] if config.gcs_metrics else None
-  metrics = None

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
@@ -242,18 +225,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
   last_profiling_step = prof.finished_initial_profile_step

   example_batch = None
-  last_step_completion = datetime.datetime.now()
-
-  performance_metric_queue = None
-  if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
-    gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
-    if config.report_heartbeat_metric_for_gcp_monitoring:
-      gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
-    if config.report_performance_metric_for_gcp_monitoring:
-      performance_metric_queue = queue.Queue()
-      gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)
-
-  metric_logger = MetricLogger(writer, config)
   step = start_step

   elastic_manager.maybe_snapshot(
@@ -267,10 +238,17 @@ def train_loop(config, elastic_manager, recorder, state=None):
  )

   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
+
+  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+
+  # Write train config params, num model params, and XLA flags to tensorboard
+  metric_logger.write_setup_info_to_tensorboard(state.params)
+
   # Using while loop instead of a for loop because with elasticity
   # the step is restored back to the latest snapshot when a slice is lost
   while step < config.steps:
     try:
+      train_step_start_time = datetime.datetime.now()
       if step == first_profiling_step or prof.should_activate_periodic_profile(step):
         optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
         prof.activate(blocking_object=state, optional_postfix=optional_postfix)
@@ -292,12 +270,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
      with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
        state, metrics = p_train_step(state, example_batch, nextrng)

-      step_time_delta = datetime.datetime.now() - last_step_completion
-      last_step_completion = datetime.datetime.now()
-      record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
-      if performance_metric_queue:
-        performance_metric_queue.put(step_time_delta.total_seconds())
-
      if checkpoint_manager is not None:
        state_to_save = state
        if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
@@ -308,7 +280,8 @@ def train_loop(config, elastic_manager, recorder, state=None):
          checkpoint_manager.wait_until_finished()
          sys.exit()

-      metric_logger.write_metrics(running_gcs_metrics, metrics, step)
+      train_step_time_delta = datetime.datetime.now() - train_step_start_time
+      metric_logger.record_train_metrics(metrics, step, train_step_time_delta)

      if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
        prof.deactivate(blocking_object=state)
@@ -347,7 +320,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
          example_batch,
          learning_rate_schedule,
          metric_logger,
-          writer,
        ) = ret

        if step == start_step:
@@ -377,7 +349,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
          example_batch,
          learning_rate_schedule,
          metric_logger,
-          writer,
        ) = ret

   if checkpoint_manager is not None:
@@ -398,8 +369,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
      max_logging.log(f"Checkpoint is already saved for step {int(state.step)-1}.")

    checkpoint_manager.wait_until_finished()
-  metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
-  max_utils.close_summary_writer(writer)
+  metric_logger.flush_metrics_and_cleanup()

   if example_batch:
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
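
Taken together, the elastic_train.py hunks converge on a small MetricLogger call-site interface: construct it from the config and learning-rate schedule, write setup info once, record each training step with its measured wall-clock time, and flush at the end. The sketch below is a minimal, hypothetical stand-in that mirrors only those call sites; the real class lives in MaxText/metric_logger.py and is not part of this commit excerpt, so the constructor arguments and method names are taken from the diff while the method bodies are placeholders.

```python
# Hypothetical stub, not MaxText's MetricLogger: only the constructor and method
# names/arguments come from the call sites in the diff above; bodies are placeholders.
import datetime


class MetricLoggerStub:
  """Stand-in showing the interface the refactored trainers call."""

  def __init__(self, config=None, learning_rate_schedule=None):
    self.config = config
    self.learning_rate_schedule = learning_rate_schedule
    # The GRPO trainer reads this dict after eval, so the real class presumably maintains it.
    self.cumulative_eval_metrics = {"scalar": {"eval/avg_loss": 0.0}}

  def write_setup_info_to_tensorboard(self, params):
    """Write config params, model parameter count, and XLA flags once at startup."""

  def record_train_metrics(self, metrics, step, step_time_delta):
    """Record one training step's metrics together with its measured step time."""

  def record_eval_metrics(self, step, metrics=None, eval_step_count=None, eval_step_time_delta=None):
    """Accumulate per-batch eval metrics, or finalize averages when counts/timings are passed."""

  def flush_metrics_and_cleanup(self):
    """Flush buffered metrics and close underlying writers (replaces close_summary_writer)."""


def train_loop_sketch(config, state, batches, p_train_step, learning_rate_schedule):
  """Minimal shape of the new per-step logging flow in elastic_train.py."""
  metric_logger = MetricLoggerStub(config=config, learning_rate_schedule=learning_rate_schedule)
  metric_logger.write_setup_info_to_tensorboard(state.params)
  for step, batch in enumerate(batches):
    train_step_start_time = datetime.datetime.now()
    state, metrics = p_train_step(state, batch)
    metric_logger.record_train_metrics(metrics, step, datetime.datetime.now() - train_step_start_time)
  metric_logger.flush_metrics_and_cleanup()
```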

MaxText/experimental/rl/grpo_trainer.py

Lines changed: 19 additions & 68 deletions
@@ -26,7 +26,6 @@
 import os
 import sys
 import functools
-import queue
 from typing import Sequence
 from collections.abc import Callable

@@ -60,13 +59,13 @@
 from MaxText.common_types import Array
 from MaxText.experimental.rl import grpo_input_pipeline
 from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
+from MaxText.globals import EPS
 from MaxText.layers import models
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import (
    validate_train_config,
    get_first_step,
    load_next_batch,
-    record_scalar_metrics,
    save_checkpoint,
    check_example_batch,
    setup_mesh_and_model,
@@ -82,7 +81,6 @@
 # pylint: disable=too-many-positional-arguments

 Transformer = models.Transformer
-EPS = 1e-8


 # -----------------------------------------------------------------------------
@@ -676,7 +674,6 @@ def train_loop(config, config_inference, recorder, state=None):
   """Main Training loop."""
   (
      init_rng,
-      writer,
      checkpoint_manager,
      state_mesh_shardings,
      model,
@@ -726,17 +723,6 @@ def train_loop(config, config_inference, recorder, state=None):
      donate_argnums_eval,
  ) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_shardings, model, config)

-  # TODO: fix tflops calculations for grpo setting
-  num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
-  max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
-  per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
-  per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)
-
-  # Write train config params, num model params, and XLA flags to tensorboard
-  max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
-  max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
-  maxtext_utils.add_config_to_summary_writer(config, writer)
-
   # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
   if config.compiled_trainstep_file != "":
     print("Loading the compiled function...", flush=True)
@@ -774,9 +760,6 @@ def train_loop(config, config_inference, recorder, state=None):
        donate_argnums=(0,),
    )

-  running_gcs_metrics = [] if config.gcs_metrics else None
-  metrics = None
-
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
   first_profiling_step = prof.start_initial_profile_step
@@ -785,20 +768,16 @@ def train_loop(config, config_inference, recorder, state=None):
   last_profiling_step = prof.finished_initial_profile_step

   example_batch = None
-  last_step_completion = datetime.datetime.now()
-
-  performance_metric_queue = None
-  if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
-    gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
-    if config.report_heartbeat_metric_for_gcp_monitoring:
-      gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
-    if config.report_performance_metric_for_gcp_monitoring:
-      performance_metric_queue = queue.Queue()
-      gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)
-
-  metric_logger = MetricLogger(writer, config)
+
   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
+
+  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+
+  # Write train config params, num model params, and XLA flags to tensorboard
+  metric_logger.write_setup_info_to_tensorboard(state.params)
+
   for step in np.arange(start_step, config.steps):
+    train_step_start_time = datetime.datetime.now()
     if step == first_profiling_step or prof.should_activate_periodic_profile(step):
       optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
       prof.activate(blocking_object=state, optional_postfix=optional_postfix)
@@ -823,12 +802,6 @@ def train_loop(config, config_inference, recorder, state=None):
    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
      state, metrics = p_train_step(state, example_batch, rng)

-    step_time_delta = datetime.datetime.now() - last_step_completion
-    last_step_completion = datetime.datetime.now()
-    record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
-    if performance_metric_queue:
-      performance_metric_queue.put(step_time_delta.total_seconds())
-
    if checkpoint_manager is not None:
      state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
      if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
@@ -839,8 +812,6 @@ def train_loop(config, config_inference, recorder, state=None):
        checkpoint_manager.wait_until_finished()
        sys.exit()

-    metric_logger.write_metrics(running_gcs_metrics, metrics, step)
-
    if config.dump_hlo and step == start_step:
      jax.block_until_ready(state)  # Ensure compilation has finished.
      max_utils.upload_dump(
@@ -851,45 +822,25 @@ def train_loop(config, config_inference, recorder, state=None):
          all_host_upload=config.dump_hlo_upload_all,
      )

+    train_step_time_delta = datetime.datetime.now() - train_step_start_time
+    metric_logger.record_train_metrics(metrics, step, train_step_time_delta)
+
    if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
      assert eval_data_iterator
-      cumulative_eval_metrics = {
-          "scalar": {
-              "eval/total_loss": 0.0,
-              "eval/total_weights": 0.0,
-              "eval/avg_loss": 0.0,
-              "eval/moe_lb_loss": 0.0,
-          }
-      }
-      eval_dpo_reward_accuracy = 0.0
+      eval_step_start_time = datetime.datetime.now()
      eval_step_count = 0
      # pylint: disable=not-callable
      for eval_batch in eval_data_iterator:
        if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
          break
        with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
          eval_metrics = p_eval_step(state, eval_batch, rng)
-        cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
-        cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
-        cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"])
-        eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0))  # for dpo only
+        metric_logger.record_eval_metrics(step, metrics=eval_metrics)
        max_logging.log(f"Completed eval step {eval_step_count}")
        eval_step_count += 1
-      eval_loss = cumulative_eval_metrics["scalar"]["eval/total_loss"] / (
-          cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS
-      )
-      cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss
-      cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = (
-          cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
-      )
-      if config.use_dpo:
-        cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count
-      metric_logger.write_metrics(running_gcs_metrics, cumulative_eval_metrics, step, is_training=False)
-      max_logging.log(
-          f"average loss after {step=}: {eval_step_count=}, {eval_loss=},"
-          f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}"
-      )
-      if eval_loss <= config.target_eval_loss:
+      eval_step_time_delta = datetime.datetime.now() - eval_step_start_time
+      metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count, eval_step_time_delta=eval_step_time_delta)
+      if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
        max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
        prof.deactivate()
        break
@@ -911,8 +862,8 @@ def train_loop(config, config_inference, recorder, state=None):
      max_logging.log(f"Checkpoint already saved for step {int(state.step)-1}.")

    checkpoint_manager.wait_until_finished()
-  metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
-  max_utils.close_summary_writer(writer)
+  metric_logger.flush_metrics_and_cleanup()
+
   if example_batch:
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       compiled = p_train_step.lower(state, example_batch, rng).compile()
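
The GRPO trainer used to accumulate eval totals inline and derive eval/avg_loss with an EPS guard; that bookkeeping now moves behind MetricLogger.record_eval_metrics, and the loop reads the result from metric_logger.cumulative_eval_metrics. The sketch below reproduces the aggregation the deleted inline code performed, which the logger presumably does internally (an assumption: its implementation is not shown in this diff); the EPS value and metric keys are taken from the removed lines.

```python
# Sketch of the aggregation the removed inline code performed; the refactored
# MetricLogger.record_eval_metrics presumably does the equivalent internally
# (assumption: the logger's implementation is not part of this diff).
EPS = 1e-8  # now exported from MaxText.globals (see the globals.py hunk below)


def aggregate_eval_metrics(per_batch_metrics):
  """Accumulate per-batch eval scalars and derive eval/avg_loss, as the old inline code did."""
  totals = {"eval/total_loss": 0.0, "eval/total_weights": 0.0, "eval/moe_lb_loss": 0.0}
  eval_step_count = 0
  for m in per_batch_metrics:
    totals["eval/total_loss"] += float(m["scalar"]["evaluation/total_loss"])
    totals["eval/total_weights"] += float(m["scalar"]["evaluation/total_weights"])
    totals["eval/moe_lb_loss"] += float(m["scalar"]["evaluation/moe_lb_loss"])
    eval_step_count += 1
  totals["eval/avg_loss"] = totals["eval/total_loss"] / (totals["eval/total_weights"] + EPS)
  # guard against zero eval batches in this sketch
  totals["eval/avg_moe_lb_loss"] = totals["eval/moe_lb_loss"] / max(eval_step_count, 1)
  return {"scalar": totals}


# Example with two fabricated eval batches: avg_loss = (4 + 6) / (2 + 3) ≈ 2.0
fake_batches = [
    {"scalar": {"evaluation/total_loss": 4.0, "evaluation/total_weights": 2.0, "evaluation/moe_lb_loss": 0.1}},
    {"scalar": {"evaluation/total_loss": 6.0, "evaluation/total_weights": 3.0, "evaluation/moe_lb_loss": 0.3}},
]
print(aggregate_eval_metrics(fake_batches)["scalar"]["eval/avg_loss"])  # ≈ 2.0
```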

MaxText/globals.py

Lines changed: 4 additions & 2 deletions
@@ -16,6 +16,8 @@

 import os.path

-PKG_DIR = os.path.dirname(os.path.abspath(__file__))
+PKG_DIR = os.path.dirname(os.path.abspath(__file__))  # MaxText directory path
+EPS = 1e-8  # Epsilon to calculate loss
+DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3  # Default checkpoint file size

-__all__ = ["PKG_DIR"]
+__all__ = ["DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE", "EPS", "PKG_DIR"]
