Skip to content

Commit 84d842c

Browse files
committed
Refactor profiler in trainers
1 parent 4d4b6b0 commit 84d842c

File tree

5 files changed

+29
-36
lines changed

5 files changed

+29
-36
lines changed

MaxText/elastic_train.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,6 @@ def train_loop(config, elastic_manager, recorder, state=None):
236236

237237
start_step = get_first_step(state) # this is the start_step for training
238238
prof = profiler.Profiler(config, offset_step=start_step)
239-
first_profiling_step = prof.start_initial_profile_step
240-
if config.profiler != "" and first_profiling_step >= config.steps:
241-
raise ValueError("Profiling requested but initial profiling step set past training final step")
242-
last_profiling_step = prof.finished_initial_profile_step
243239

244240
example_batch = None
245241
last_step_completion = datetime.datetime.now()
@@ -271,9 +267,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
271267
# the step is restored back to the latest snapshot when a slice is lost
272268
while step < config.steps:
273269
try:
274-
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
275-
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
276-
prof.activate(blocking_object=state, optional_postfix=optional_postfix)
270+
prof.maybe_activate_profiler(step, state)
277271

278272
max_logging.log(f"{step=} {elastic_manager.elastic_down_event_count=} {elastic_manager.good_slice_count=}")
279273
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules), jax.default_device(elastic_manager.default_device):
@@ -310,8 +304,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
310304

311305
metric_logger.write_metrics(running_gcs_metrics, metrics, step)
312306

313-
if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
314-
prof.deactivate(blocking_object=state)
307+
prof.maybe_deactivate_profiler(step, state)
315308

316309
elastic_manager.maybe_snapshot(
317310
step=step,

MaxText/experimental/rl/grpo_trainer.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -779,10 +779,6 @@ def train_loop(config, config_inference, recorder, state=None):
779779

780780
start_step = get_first_step(state) # this is the start_step for training
781781
prof = profiler.Profiler(config, offset_step=start_step)
782-
first_profiling_step = prof.start_initial_profile_step
783-
if config.profiler != "" and first_profiling_step >= config.steps:
784-
raise ValueError("Profiling requested but initial profiling step set past training final step")
785-
last_profiling_step = prof.finished_initial_profile_step
786782

787783
example_batch = None
788784
last_step_completion = datetime.datetime.now()
@@ -799,9 +795,7 @@ def train_loop(config, config_inference, recorder, state=None):
799795
metric_logger = MetricLogger(writer, config)
800796
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
801797
for step in np.arange(start_step, config.steps):
802-
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
803-
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
804-
prof.activate(blocking_object=state, optional_postfix=optional_postfix)
798+
prof.maybe_activate_profiler(step, state)
805799

806800
with jax.profiler.StepTraceAnnotation("train", step_num=step):
807801
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -894,8 +888,7 @@ def train_loop(config, config_inference, recorder, state=None):
894888
prof.deactivate()
895889
break
896890

897-
if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
898-
prof.deactivate(blocking_object=state)
891+
prof.maybe_deactivate_profiler(step, state)
899892

900893
if step == start_step:
901894
max_utils.print_mem_stats("After params initialized")

MaxText/profiler.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,18 @@ def __init__(self, config, offset_step=0):
4040
self.profile_period = config.profile_periodically_period
4141
self.start_initial_profile_step = self._set_first_profiler_step(config.skip_first_n_steps_for_profiler, offset_step)
4242
self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps)
43+
if config.profiler != "" and self.start_initial_profile_step >= config.steps:
44+
raise ValueError("Profiling requested but initial profiling step set past training final step")
45+
46+
def maybe_activate_profiler(self, step, state):
47+
"""Conditionally activates the profiler based on the current step.
48+
This method checks if the current training step matches the step designated
49+
for starting an initial profile, or if it meets the criteria for
50+
activating a new periodic profile.
51+
"""
52+
if step == self.start_initial_profile_step or self.should_activate_periodic_profile(step):
53+
optional_postfix = f"step_{step}" if self.config.profile_periodically_period > 0 else ""
54+
self.activate(blocking_object=state, optional_postfix=optional_postfix)
4355

4456
def activate(self, blocking_object=None, optional_postfix=""):
4557
"""Start the profiler.
@@ -60,6 +72,15 @@ def activate(self, blocking_object=None, optional_postfix=""):
6072
elif self.mode == "xplane":
6173
jax.profiler.start_trace(self.output_path)
6274

75+
def maybe_deactivate_profiler(self, step, state):
76+
"""Conditionally deactivates the profiler based on the current step.
77+
This method checks if the current training step matches the step designated
78+
for finishing the initial profile, or if it meets the criteria for
79+
deactivating a periodic profile.
80+
"""
81+
if step == self.finished_initial_profile_step or self.should_deactivate_periodic_profile(step):
82+
self.deactivate(blocking_object=state)
83+
6384
def deactivate(self, blocking_object=None):
6485
"""End the profiler.
6586
The result is uploaded to the output bucket."""

MaxText/sft_trainer.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,6 @@ def train_loop(config, recorder, state=None):
137137

138138
start_step = get_first_step(state) # this is the start_step for training
139139
prof = profiler.Profiler(config, offset_step=start_step)
140-
first_profiling_step = prof.start_initial_profile_step
141-
if config.profiler != "" and first_profiling_step >= config.steps:
142-
raise ValueError("Profiling requested but initial profiling step set past training final step")
143-
last_profiling_step = prof.finished_initial_profile_step
144140

145141
example_batch = None
146142
last_step_completion = datetime.datetime.now()
@@ -157,9 +153,7 @@ def train_loop(config, recorder, state=None):
157153
metric_logger = MetricLogger(writer, config)
158154
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
159155
for step in np.arange(start_step, config.steps):
160-
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
161-
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
162-
prof.activate(blocking_object=state, optional_postfix=optional_postfix)
156+
prof.maybe_activate_profiler(step, state)
163157

164158
with jax.profiler.StepTraceAnnotation("train", step_num=step):
165159
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -243,8 +237,7 @@ def train_loop(config, recorder, state=None):
243237
prof.deactivate()
244238
break
245239

246-
if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
247-
prof.deactivate(blocking_object=state)
240+
prof.maybe_deactivate_profiler(step, state)
248241

249242
if step == start_step:
250243
max_utils.print_mem_stats("After params initialized")

MaxText/train.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -818,10 +818,6 @@ def train_loop(config, recorder, state=None):
818818

819819
start_step = get_first_step(state) # this is the start_step for training
820820
prof = profiler.Profiler(config, offset_step=start_step)
821-
first_profiling_step = prof.start_initial_profile_step
822-
if config.profiler != "" and first_profiling_step >= config.steps:
823-
raise ValueError("Profiling requested but initial profiling step set past training final step")
824-
last_profiling_step = prof.finished_initial_profile_step
825821

826822
example_batch = None
827823
last_step_completion = datetime.datetime.now()
@@ -838,9 +834,7 @@ def train_loop(config, recorder, state=None):
838834
metric_logger = MetricLogger(writer, config)
839835
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
840836
for step in np.arange(start_step, config.steps):
841-
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
842-
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
843-
prof.activate(blocking_object=state, optional_postfix=optional_postfix)
837+
prof.maybe_activate_profiler(step, state)
844838

845839
with jax.profiler.StepTraceAnnotation("train", step_num=step):
846840
with maybe_record_goodput(recorder, GoodputEvent.DATA_LOADING):
@@ -930,8 +924,7 @@ def train_loop(config, recorder, state=None):
930924
prof.deactivate()
931925
break
932926

933-
if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
934-
prof.deactivate(blocking_object=state)
927+
prof.maybe_deactivate_profiler(step, state)
935928

936929
if step == start_step:
937930
max_utils.print_mem_stats("After params initialized")

0 commit comments

Comments (0)