 from MaxText.common_types import Array
 from MaxText.experimental.rl import grpo_input_pipeline
 from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
+from MaxText.globals import EPS
 from MaxText.layers import models
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import (
     validate_train_config,
     get_first_step,
     load_next_batch,
-    record_scalar_metrics,
     save_checkpoint,
     check_example_batch,
     setup_mesh_and_model,
@@ -82,7 +82,6 @@
 # pylint: disable=too-many-positional-arguments

 Transformer = models.Transformer
-EPS = 1e-8


 # -----------------------------------------------------------------------------
@@ -676,7 +675,6 @@ def train_loop(config, config_inference, recorder, state=None):
   """Main Training loop."""
   (
       init_rng,
-      writer,
       checkpoint_manager,
       state_mesh_shardings,
       model,
@@ -727,15 +725,10 @@ def train_loop(config, config_inference, recorder, state=None):
   ) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_shardings, model, config)

   # TODO: fix tflops calculations for grpo setting
-  num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
-  max_logging.log(f"number parameters: {num_model_parameters / 1e9:.3f} billion")
-  per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
-  per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)
+  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

   # Write train config params, num model params, and XLA flags to tensorboard
-  max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
-  max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
-  maxtext_utils.add_config_to_summary_writer(config, writer)
+  metric_logger.write_setup_info_to_tensorboard(state.params)

   # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
   if config.compiled_trainstep_file != "":
@@ -774,9 +767,6 @@ def train_loop(config, config_inference, recorder, state=None):
       donate_argnums=(0,),
   )

-  running_gcs_metrics = [] if config.gcs_metrics else None
-  metrics = None
-
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
   first_profiling_step = prof.start_initial_profile_step
@@ -785,19 +775,14 @@ def train_loop(config, config_inference, recorder, state=None):
   last_profiling_step = prof.finished_initial_profile_step

   example_batch = None
-  last_step_completion = datetime.datetime.now()
-
-  performance_metric_queue = None
-  if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
-    gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
-    if config.report_heartbeat_metric_for_gcp_monitoring:
-      gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
-    if config.report_performance_metric_for_gcp_monitoring:
-      performance_metric_queue = queue.Queue()
-      gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)
-
-  metric_logger = MetricLogger(writer, config)
+
   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
+
+  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+
+  # Write train config params, num model params, and XLA flags to tensorboard
+  metric_logger.write_setup_info_to_tensorboard(state.params)
+
   for step in np.arange(start_step, config.steps):
     if step == first_profiling_step or prof.should_activate_periodic_profile(step):
       optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
@@ -823,11 +808,7 @@ def train_loop(config, config_inference, recorder, state=None):
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       state, metrics = p_train_step(state, example_batch, rng)

-    step_time_delta = datetime.datetime.now() - last_step_completion
-    last_step_completion = datetime.datetime.now()
-    record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
-    if performance_metric_queue:
-      performance_metric_queue.put(step_time_delta.total_seconds())
+    metric_logger.record_train_metrics(metrics, step)

     if checkpoint_manager is not None:
       state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
@@ -839,8 +820,6 @@ def train_loop(config, config_inference, recorder, state=None):
         checkpoint_manager.wait_until_finished()
         sys.exit()

-    metric_logger.write_metrics(running_gcs_metrics, metrics, step)
-
     if config.dump_hlo and step == start_step:
       jax.block_until_ready(state)  # Ensure compilation has finished.
       max_utils.upload_dump(
@@ -853,43 +832,18 @@ def train_loop(config, config_inference, recorder, state=None):

     if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
       assert eval_data_iterator
-      cumulative_eval_metrics = {
-          "scalar": {
-              "eval/total_loss": 0.0,
-              "eval/total_weights": 0.0,
-              "eval/avg_loss": 0.0,
-              "eval/moe_lb_loss": 0.0,
-          }
-      }
-      eval_dpo_reward_accuracy = 0.0
       eval_step_count = 0
       # pylint: disable=not-callable
       for eval_batch in eval_data_iterator:
         if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
           break
         with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
           eval_metrics = p_eval_step(state, eval_batch, rng)
-        cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
-        cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
-        cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"])
-        eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0))  # for dpo only
+        metric_logger.record_eval_metrics(metrics=eval_metrics, eval_step_count=None, step=step)
         max_logging.log(f"Completed eval step {eval_step_count}")
         eval_step_count += 1
-      eval_loss = cumulative_eval_metrics["scalar"]["eval/total_loss"] / (
-          cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS
-      )
-      cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss
-      cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = (
-          cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
-      )
-      if config.use_dpo:
-        cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count
-      metric_logger.write_metrics(running_gcs_metrics, cumulative_eval_metrics, step, is_training=False)
-      max_logging.log(
-          f"average loss after {step=}: {eval_step_count=}, {eval_loss=},"
-          f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}"
-      )
-      if eval_loss <= config.target_eval_loss:
+      metric_logger.record_eval_metrics(metrics=None, eval_step_count=eval_step_count, step=step)
+      if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
         max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
         prof.deactivate()
         break
@@ -911,8 +865,8 @@ def train_loop(config, config_inference, recorder, state=None):
       max_logging.log(f"Checkpoint already saved for step {int(state.step)-1}.")

     checkpoint_manager.wait_until_finished()
-  metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
-  max_utils.close_summary_writer(writer)
+  metric_logger.flush_metrics_and_cleanup()
+
   if example_batch:
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       compiled = p_train_step.lower(state, example_batch, rng).compile()
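
Note for readers of this diff: the refactor replaces the trainer's inline metric bookkeeping (step timing, TFLOPs and token counts, the GCP performance queue, the hand-rolled cumulative eval dict, and summary-writer teardown) with a single MetricLogger object. The sketch below is only an illustration of the interface the new call sites assume; the real class is MaxText.metric_logger.MetricLogger, which is not part of this diff, so the stand-in name MetricLoggerSketch and the method bodies are assumptions, reconstructed from the deleted inline code above.

# Illustrative stand-in only -- the real implementation is MaxText.metric_logger.MetricLogger,
# which is not shown in this diff. Method names and signatures mirror the call sites above;
# the bodies re-create the accumulation logic that the deleted inline code used to perform.
EPS = 1e-8  # the trainer now imports this from MaxText.globals instead of redefining it


class MetricLoggerSketch:
  """Assumed interface: owns metric accumulation, TensorBoard setup info, and cleanup."""

  def __init__(self, config=None, learning_rate_schedule=None):
    self.config = config
    self.learning_rate_schedule = learning_rate_schedule
    self.cumulative_eval_metrics = {
        "scalar": {
            "eval/total_loss": 0.0,
            "eval/total_weights": 0.0,
            "eval/avg_loss": 0.0,
            "eval/moe_lb_loss": 0.0,
        }
    }

  def write_setup_info_to_tensorboard(self, params):
    # Stands in for the removed add_text_to_summary_writer / add_config_to_summary_writer calls.
    print(f"setup info recorded for {len(params)} top-level parameter collections")

  def record_train_metrics(self, metrics, step):
    # Stands in for the removed step-time, TFLOPs, and GCP performance-queue bookkeeping.
    print(f"train metrics recorded for step {step}: {sorted(metrics['scalar'])}")

  def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
    scalars = self.cumulative_eval_metrics["scalar"]
    if metrics is not None:  # per-eval-batch accumulation, previously inline in the eval loop
      scalars["eval/total_loss"] += float(metrics["scalar"]["evaluation/total_loss"])
      scalars["eval/total_weights"] += float(metrics["scalar"]["evaluation/total_weights"])
      scalars["eval/moe_lb_loss"] += float(metrics["scalar"]["evaluation/moe_lb_loss"])
    if eval_step_count:  # end-of-eval finalisation, previously the avg_loss / avg_moe_lb_loss block
      scalars["eval/avg_loss"] = scalars["eval/total_loss"] / (scalars["eval/total_weights"] + EPS)
      scalars["eval/avg_moe_lb_loss"] = scalars["eval/moe_lb_loss"] / eval_step_count
      print(f"eval after step {step}: avg_loss={scalars['eval/avg_loss']:.4f}")

  def flush_metrics_and_cleanup(self):
    # Stands in for the final write_metrics(...) call plus max_utils.close_summary_writer(writer).
    print("final metrics flushed; summary writer closed")

With that interface, the loop reduces to: construct the logger once, call write_setup_info_to_tensorboard(state.params), call record_train_metrics(metrics, step) after each train step, call record_eval_metrics per eval batch and once more with eval_step_count to finalise the averages, then flush_metrics_and_cleanup() after the final checkpoint.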