@@ -26,7 +26,6 @@
 import os
 import sys
 import functools
-import queue
 from typing import Sequence
 from collections.abc import Callable
 
@@ -60,13 +59,13 @@
 from MaxText.common_types import Array
 from MaxText.experimental.rl import grpo_input_pipeline
 from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
+from MaxText.globals import EPS
 from MaxText.layers import models
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import (
     validate_train_config,
     get_first_step,
     load_next_batch,
-    record_scalar_metrics,
     save_checkpoint,
     check_example_batch,
     setup_mesh_and_model,
@@ -82,7 +81,6 @@
 # pylint: disable=too-many-positional-arguments
 
 Transformer = models.Transformer
-EPS = 1e-8
 
 
 # -----------------------------------------------------------------------------
@@ -676,7 +674,6 @@ def train_loop(config, config_inference, recorder, state=None):
   """Main Training loop."""
   (
       init_rng,
-      writer,
       checkpoint_manager,
       state_mesh_shardings,
       model,
@@ -726,17 +723,6 @@ def train_loop(config, config_inference, recorder, state=None):
       donate_argnums_eval,
   ) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_shardings, model, config)
 
-  # TODO: fix tflops calculations for grpo setting
-  num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
-  max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
-  per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
-  per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)
-
-  # Write train config params, num model params, and XLA flags to tensorboard
-  max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
-  max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
-  maxtext_utils.add_config_to_summary_writer(config, writer)
-
   # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
   if config.compiled_trainstep_file != "":
     print("Loading the compiled function...", flush=True)
@@ -774,9 +760,6 @@ def train_loop(config, config_inference, recorder, state=None):
         donate_argnums=(0,),
     )
 
-  running_gcs_metrics = [] if config.gcs_metrics else None
-  metrics = None
-
   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
   first_profiling_step = prof.start_initial_profile_step
@@ -785,20 +768,16 @@ def train_loop(config, config_inference, recorder, state=None):
   last_profiling_step = prof.finished_initial_profile_step
 
   example_batch = None
-  last_step_completion = datetime.datetime.now()
-
-  performance_metric_queue = None
-  if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
-    gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
-    if config.report_heartbeat_metric_for_gcp_monitoring:
-      gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
-    if config.report_performance_metric_for_gcp_monitoring:
-      performance_metric_queue = queue.Queue()
-      gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)
-
-  metric_logger = MetricLogger(writer, config)
+
   input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
+
+  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+
+  # Write train config params, num model params, and XLA flags to tensorboard
+  metric_logger.write_setup_info_to_tensorboard(state.params)
+
   for step in np.arange(start_step, config.steps):
+    train_step_start_time = datetime.datetime.now()
     if step == first_profiling_step or prof.should_activate_periodic_profile(step):
       optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
       prof.activate(blocking_object=state, optional_postfix=optional_postfix)
@@ -823,12 +802,6 @@ def train_loop(config, config_inference, recorder, state=None):
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       state, metrics = p_train_step(state, example_batch, rng)
 
-    step_time_delta = datetime.datetime.now() - last_step_completion
-    last_step_completion = datetime.datetime.now()
-    record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
-    if performance_metric_queue:
-      performance_metric_queue.put(step_time_delta.total_seconds())
-
     if checkpoint_manager is not None:
       state_to_save = state if not config.use_dpo else _split_grpo_state(state)[0]
       if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
@@ -839,8 +812,6 @@ def train_loop(config, config_inference, recorder, state=None):
         checkpoint_manager.wait_until_finished()
         sys.exit()
 
-    metric_logger.write_metrics(running_gcs_metrics, metrics, step)
-
     if config.dump_hlo and step == start_step:
       jax.block_until_ready(state)  # Ensure compilation has finished.
       max_utils.upload_dump(
@@ -851,45 +822,25 @@ def train_loop(config, config_inference, recorder, state=None):
           all_host_upload=config.dump_hlo_upload_all,
       )
 
+    train_step_time_delta = datetime.datetime.now() - train_step_start_time
+    metric_logger.record_train_metrics(metrics, step, train_step_time_delta)
+
     if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
       assert eval_data_iterator
-      cumulative_eval_metrics = {
-          "scalar": {
-              "eval/total_loss": 0.0,
-              "eval/total_weights": 0.0,
-              "eval/avg_loss": 0.0,
-              "eval/moe_lb_loss": 0.0,
-          }
-      }
-      eval_dpo_reward_accuracy = 0.0
+      eval_step_start_time = datetime.datetime.now()
       eval_step_count = 0
       # pylint: disable=not-callable
       for eval_batch in eval_data_iterator:
         if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
           break
         with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
           eval_metrics = p_eval_step(state, eval_batch, rng)
-        cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
-        cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
-        cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"])
-        eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0))  # for dpo only
+        metric_logger.record_eval_metrics(step, metrics=eval_metrics)
         max_logging.log(f"Completed eval step {eval_step_count}")
         eval_step_count += 1
-      eval_loss = cumulative_eval_metrics["scalar"]["eval/total_loss"] / (
-          cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS
-      )
-      cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss
-      cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = (
-          cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
-      )
-      if config.use_dpo:
-        cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count
-      metric_logger.write_metrics(running_gcs_metrics, cumulative_eval_metrics, step, is_training=False)
-      max_logging.log(
-          f"average loss after {step=}: {eval_step_count=}, {eval_loss=},"
-          f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}"
-      )
-      if eval_loss <= config.target_eval_loss:
+      eval_step_time_delta = datetime.datetime.now() - eval_step_start_time
+      metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count, eval_step_time_delta=eval_step_time_delta)
+      if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
         max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
         prof.deactivate()
         break
@@ -911,8 +862,8 @@ def train_loop(config, config_inference, recorder, state=None):
       max_logging.log(f"Checkpoint already saved for step {int(state.step)-1}.")
 
     checkpoint_manager.wait_until_finished()
-  metric_logger.write_metrics(running_gcs_metrics, metrics, config.steps - 1)  # final step metrics
-  max_utils.close_summary_writer(writer)
+  metric_logger.flush_metrics_and_cleanup()
+
   if example_batch:
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       compiled = p_train_step.lower(state, example_batch, rng).compile()
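
Taken together, the hunks above replace the hand-rolled writer/queue/`record_scalar_metrics` plumbing with a single `MetricLogger` object. Below is a minimal sketch of the resulting flow, assuming the `MetricLogger` API exactly as it appears in this diff; `run_with_metric_logger` and its parameters (`config`, `state`, `p_train_step`, `learning_rate_schedule`, `data_iterator`, `start_step`) are illustrative stand-ins for the surrounding setup code, not part of the change:

```python
import datetime

import numpy as np

from MaxText.metric_logger import MetricLogger


def run_with_metric_logger(config, state, p_train_step, learning_rate_schedule, data_iterator, start_step):
  """Sketch of the metrics flow after this change; helper name and arguments are hypothetical."""
  # One logger owns TensorBoard/GCS reporting, so the per-step timing, queue, and
  # summary-writer bookkeeping removed in this diff is no longer needed here.
  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
  metric_logger.write_setup_info_to_tensorboard(state.params)  # config, param count, XLA flags

  for step in np.arange(start_step, config.steps):
    train_step_start_time = datetime.datetime.now()
    example_batch = next(data_iterator)  # batch loading simplified for the sketch
    state, metrics = p_train_step(state, example_batch, None)  # rng handling elided
    train_step_time_delta = datetime.datetime.now() - train_step_start_time
    metric_logger.record_train_metrics(metrics, step, train_step_time_delta)

  metric_logger.flush_metrics_and_cleanup()
```

Eval follows the same pattern in the diff: `record_eval_metrics(step, metrics=eval_metrics)` per eval batch, then one call with `eval_step_count` and `eval_step_time_delta` to finalize the averages that the old code computed by hand with `EPS`.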