Refactor: Recording and logging training and evaluation metrics in all trainers #1815

Status: Open · wants to merge 1 commit into base: main
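
In both trainers touched here, the loop stops threading a TensorBoard `writer`, `running_gcs_metrics`, and a GCP performance queue through every call and instead hands metric handling to `MetricLogger`. Below is a minimal runnable sketch of the lifecycle the new call sites imply; the class is a hypothetical stand-in for illustration only, and only the method names and constructor keywords come from this diff.

```python
import datetime


class SketchMetricLogger:
  """Hypothetical stand-in for MaxText.metric_logger.MetricLogger (illustration only)."""

  def __init__(self, config=None, learning_rate_schedule=None):
    self.config = config
    self.learning_rate_schedule = learning_rate_schedule
    self.buffered = []

  def write_setup_info_to_tensorboard(self, params):
    # The real method logs config, parameter count, and XLA flags to TensorBoard.
    print(f"setup: {len(params)} parameter entries")

  def record_train_metrics(self, metrics, step, step_time_delta):
    self.buffered.append((step, metrics, step_time_delta.total_seconds()))

  def flush_metrics_and_cleanup(self):
    # The real method writes any remaining metrics and closes the summary writer.
    for step, metrics, seconds in self.buffered:
      print(f"step {step}: loss={metrics['loss']:.3f} ({seconds:.3f}s)")
    self.buffered.clear()


logger = SketchMetricLogger(config={}, learning_rate_schedule=lambda step: 1e-3)
logger.write_setup_info_to_tensorboard(params={"w": [0.0, 1.0]})
for step in range(2):
  start = datetime.datetime.now()
  metrics = {"loss": 1.0 / (step + 1)}  # stands in for the p_train_step output
  logger.record_train_metrics(metrics, step, datetime.datetime.now() - start)
logger.flush_metrics_and_cleanup()
```
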
54 changes: 12 additions & 42 deletions MaxText/elastic_train.py
@@ -43,7 +43,6 @@
import os
import sys
import time
import queue

from absl import app

@@ -68,13 +67,11 @@
from MaxText import max_logging
from MaxText import profiler
from MaxText import pyconfig
from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator
from MaxText.metric_logger import MetricLogger
from MaxText.train import check_example_batch
from MaxText.train import get_first_step
from MaxText.train import load_next_batch
from MaxText.train import record_scalar_metrics
from MaxText.train import save_checkpoint
from MaxText.train import setup_mesh_and_model
from MaxText.train import setup_train_loop
@@ -120,7 +117,7 @@ def elastic_handler(
checkpoint_manager.close()

with jax.default_device(elastic_manager.default_device):
init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(
init_rng, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(
config, elastic_manager.good_devices
)
with mesh:
@@ -171,7 +168,7 @@ def elastic_handler(
)

example_batch = None
metric_logger = MetricLogger(writer, config)
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

jax.block_until_ready(state)

@@ -186,15 +183,13 @@
example_batch,
learning_rate_schedule,
metric_logger,
writer,
)


def train_loop(config, elastic_manager, recorder, state=None):
"""Main Training loop."""
(
init_rng,
writer,
checkpoint_manager,
state_mesh_shardings,
model,
@@ -214,25 +209,13 @@ def train_loop(config, elastic_manager, recorder, state=None):
donate_argnums_train,
) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_shardings, model, config)

num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)

# Write train config params, num model params, and XLA flags to tensorboard
max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
maxtext_utils.add_config_to_summary_writer(config, writer)

p_train_step = jax.jit(
functional_train,
in_shardings=in_shard_train,
out_shardings=out_shard_train,
static_argnums=static_argnums_train,
donate_argnums=donate_argnums_train,
)
running_gcs_metrics = [] if config.gcs_metrics else None
metrics = None

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
@@ -242,18 +225,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
last_profiling_step = prof.finished_initial_profile_step

example_batch = None
last_step_completion = datetime.datetime.now()

performance_metric_queue = None
if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
if config.report_heartbeat_metric_for_gcp_monitoring:
gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
if config.report_performance_metric_for_gcp_monitoring:
performance_metric_queue = queue.Queue()
gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)

metric_logger = MetricLogger(writer, config)
step = start_step

elastic_manager.maybe_snapshot(
@@ -267,10 +238,17 @@
)

input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

# Write train config params, num model params, and XLA flags to tensorboard
metric_logger.write_setup_info_to_tensorboard(state.params)

# Using while loop instead of a for loop because with elasticity
# the step is restored back to the latest snapshot when a slice is lost
while step < config.steps:
try:
train_step_start_time = datetime.datetime.now()
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
prof.activate(blocking_object=state, optional_postfix=optional_postfix)
@@ -292,12 +270,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
state, metrics = p_train_step(state, example_batch, nextrng)

step_time_delta = datetime.datetime.now() - last_step_completion
last_step_completion = datetime.datetime.now()
record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
if performance_metric_queue:
performance_metric_queue.put(step_time_delta.total_seconds())

if checkpoint_manager is not None:
state_to_save = state
if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
@@ -308,7 +280,8 @@ def train_loop(config, elastic_manager, recorder, state=None):
checkpoint_manager.wait_until_finished()
sys.exit()

metric_logger.write_metrics(running_gcs_metrics, metrics, step)
train_step_time_delta = datetime.datetime.now() - train_step_start_time
metric_logger.record_train_metrics(metrics, step, train_step_time_delta)

if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
prof.deactivate(blocking_object=state)
@@ -347,7 +320,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
example_batch,
learning_rate_schedule,
metric_logger,
writer,
) = ret

if step == start_step:
@@ -377,7 +349,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
example_batch,
learning_rate_schedule,
metric_logger,
writer,
) = ret

if checkpoint_manager is not None:
@@ -398,8 +369,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
max_logging.log(f"Checkpoint is already saved for step {int(state.step)-1}.")

checkpoint_manager.wait_until_finished()
metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1) # final step metrics
max_utils.close_summary_writer(writer)
metric_logger.flush_metrics_and_cleanup()

if example_batch:
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
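
One behavioral detail visible in elastic_train.py above: step timing moves from a delta against `last_step_completion` (stamped after the previous `p_train_step`) to a `train_step_start_time` stamp at the top of the iteration, with the delta taken just before `metric_logger.record_train_metrics`. A small self-contained sketch of the new shape; the train step and recorder below are dummies, not MaxText code.

```python
import datetime
import time


def dummy_train_step(step):
  # Stands in for p_train_step(state, example_batch, nextrng).
  time.sleep(0.01)
  return {"scalar": {"learning/loss": 1.0 / (step + 1)}}


def dummy_record_train_metrics(metrics, step, step_time_delta):
  # Stands in for metric_logger.record_train_metrics(metrics, step, step_time_delta).
  loss = metrics["scalar"]["learning/loss"]
  print(f"step={step} loss={loss:.3f} step_time={step_time_delta.total_seconds():.3f}s")


for step in range(3):
  train_step_start_time = datetime.datetime.now()
  metrics = dummy_train_step(step)
  train_step_time_delta = datetime.datetime.now() - train_step_start_time
  dummy_record_train_metrics(metrics, step, train_step_time_delta)
```
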
87 changes: 19 additions & 68 deletions MaxText/experimental/rl/grpo_trainer.py
@@ -26,7 +26,6 @@
import os
import sys
import functools
import queue
from typing import Sequence
from collections.abc import Callable

@@ -60,13 +59,13 @@
from MaxText.common_types import Array
from MaxText.experimental.rl import grpo_input_pipeline
from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
from MaxText.globals import EPS
from MaxText.layers import models
from MaxText.metric_logger import MetricLogger
from MaxText.train import (
validate_train_config,
get_first_step,
load_next_batch,
record_scalar_metrics,
save_checkpoint,
check_example_batch,
setup_mesh_and_model,
@@ -82,7 +81,6 @@
# pylint: disable=too-many-positional-arguments

Transformer = models.Transformer
EPS = 1e-8


# -----------------------------------------------------------------------------
@@ -676,7 +674,6 @@ def train_loop(config, config_inference, recorder, state=None):
"""Main Training loop."""
(
init_rng,
writer,
checkpoint_manager,
state_mesh_shardings,
model,
@@ -726,17 +723,6 @@
donate_argnums_eval,
) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_shardings, model, config)

# TODO: fix tflops calculations for grpo setting
num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)

# Write train config params, num model params, and XLA flags to tensorboard
max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
maxtext_utils.add_config_to_summary_writer(config, writer)

# Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
if config.compiled_trainstep_file != "":
print("Loading the compiled function...", flush=True)
@@ -774,9 +760,6 @@
donate_argnums=(0,),
)

running_gcs_metrics = [] if config.gcs_metrics else None
metrics = None

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
first_profiling_step = prof.start_initial_profile_step
@@ -785,20 +768,16 @@
last_profiling_step = prof.finished_initial_profile_step

example_batch = None
last_step_completion = datetime.datetime.now()

performance_metric_queue = None
if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
if config.report_heartbeat_metric_for_gcp_monitoring:
gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
if config.report_performance_metric_for_gcp_monitoring:
performance_metric_queue = queue.Queue()
gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)

metric_logger = MetricLogger(writer, config)

input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

# Write train config params, num model params, and XLA flags to tensorboard
metric_logger.write_setup_info_to_tensorboard(state.params)

for step in np.arange(start_step, config.steps):
train_step_start_time = datetime.datetime.now()
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
prof.activate(blocking_object=state, optional_postfix=optional_postfix)
@@ -823,12 +802,6 @@
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, rng)

step_time_delta = datetime.datetime.now() - last_step_completion
last_step_completion = datetime.datetime.now()
record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
if performance_metric_queue:
performance_metric_queue.put(step_time_delta.total_seconds())

if checkpoint_manager is not None:
state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
@@ -839,8 +812,6 @@
checkpoint_manager.wait_until_finished()
sys.exit()

metric_logger.write_metrics(running_gcs_metrics, metrics, step)

if config.dump_hlo and step == start_step:
jax.block_until_ready(state) # Ensure compilation has finished.
max_utils.upload_dump(
@@ -851,45 +822,25 @@
all_host_upload=config.dump_hlo_upload_all,
)

train_step_time_delta = datetime.datetime.now() - train_step_start_time
metric_logger.record_train_metrics(metrics, step, train_step_time_delta)

if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
assert eval_data_iterator
cumulative_eval_metrics = {
"scalar": {
"eval/total_loss": 0.0,
"eval/total_weights": 0.0,
"eval/avg_loss": 0.0,
"eval/moe_lb_loss": 0.0,
}
}
eval_dpo_reward_accuracy = 0.0
eval_step_start_time = datetime.datetime.now()
eval_step_count = 0
# pylint: disable=not-callable
for eval_batch in eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, rng)
cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"])
eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)) # for dpo only
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
eval_step_count += 1
eval_loss = cumulative_eval_metrics["scalar"]["eval/total_loss"] / (
cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS
)
cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss
cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = (
cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
)
if config.use_dpo:
cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count
metric_logger.write_metrics(running_gcs_metrics, cumulative_eval_metrics, step, is_training=False)
max_logging.log(
f"average loss after {step=}: {eval_step_count=}, {eval_loss=},"
f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}"
)
if eval_loss <= config.target_eval_loss:
eval_step_time_delta = datetime.datetime.now() - eval_step_start_time
metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count, eval_step_time_delta=eval_step_time_delta)
if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
prof.deactivate()
break
@@ -911,8 +862,8 @@
max_logging.log(f"Checkpoint already saved for step {int(state.step)-1}.")

checkpoint_manager.wait_until_finished()
metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1) # final step metrics
max_utils.close_summary_writer(writer)
metric_logger.flush_metrics_and_cleanup()

if example_batch:
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
compiled = p_train_step.lower(state, example_batch, rng).compile()
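
The biggest simplification in grpo_trainer.py is the eval loop: the hand-rolled `cumulative_eval_metrics` dict, the `EPS`-guarded average, and the `write_metrics(..., is_training=False)` call are replaced by per-batch `metric_logger.record_eval_metrics(step, metrics=...)` plus a final call carrying `eval_step_count` and `eval_step_time_delta`, after which the aggregate is read back from `metric_logger.cumulative_eval_metrics`. Below is a self-contained sketch of that aggregation; the class is a hypothetical stand-in, and only the dictionary keys and the `EPS` guard come from the removed code.

```python
EPS = 1e-8  # same epsilon this PR moves into MaxText/globals.py


class EvalAggregatorSketch:
  """Hypothetical stand-in for the aggregation MetricLogger.record_eval_metrics now owns."""

  def __init__(self):
    self.cumulative_eval_metrics = {
        "scalar": {"eval/total_loss": 0.0, "eval/total_weights": 0.0, "eval/avg_loss": 0.0}
    }

  def record_batch(self, eval_metrics):
    scalars = self.cumulative_eval_metrics["scalar"]
    scalars["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
    scalars["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])

  def finalize(self):
    scalars = self.cumulative_eval_metrics["scalar"]
    scalars["eval/avg_loss"] = scalars["eval/total_loss"] / (scalars["eval/total_weights"] + EPS)
    return scalars["eval/avg_loss"]


agg = EvalAggregatorSketch()
for batch in (
    {"scalar": {"evaluation/total_loss": 12.0, "evaluation/total_weights": 4.0}},
    {"scalar": {"evaluation/total_loss": 9.0, "evaluation/total_weights": 3.0}},
):
  agg.record_batch(batch)

avg = agg.finalize()
print(f"eval/avg_loss = {avg:.4f}")  # 21.0 / (7.0 + 1e-8) ≈ 3.0
assert avg <= 3.5  # mirrors the target_eval_loss early-stop check in the trainer
```
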
6 changes: 4 additions & 2 deletions MaxText/globals.py
@@ -16,6 +16,8 @@

import os.path

PKG_DIR = os.path.dirname(os.path.abspath(__file__))
PKG_DIR = os.path.dirname(os.path.abspath(__file__)) # MaxText directory path
EPS = 1e-8 # Epsilon to calculate loss
DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3 # Default checkpoint file size

__all__ = ["PKG_DIR"]
__all__ = ["DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE", "EPS", "PKG_DIR"]