Commit 02b6b8d

Author: maxtext authors
Merge pull request #1823 from AI-Hypercomputer:sujinesh/mcjax_long_running
PiperOrigin-RevId: 771226904
2 parents 244a071 + 1f2f59d
6 files changed (+163, -15 lines)

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 3 additions & 0 deletions

@@ -33,6 +33,7 @@
     "checkpoint_storage_use_ocdbt": False,
     "checkpoint_storage_use_zarr3": False,
     "enable_pathways_goodput": True,
+    "enable_goodput_recording": True,
     "enable_single_controller": True,
     "metrics_file": "metrics.txt",
     "goodput_upload_interval_seconds": 30,
@@ -44,6 +45,7 @@
     "async_checkpointing": True,
     "checkpoint_period": 100,
     "enable_checkpoint_cloud_logger": True,
+    "enable_goodput_recording": True,
 }

 # The set of tuning params required for short-running pathways jobs.
@@ -52,6 +54,7 @@
     "async_checkpointing": True,
     "checkpoint_period": 20,
     "enable_checkpoint_cloud_logger": True,
+    "enable_goodput_recording": True,
 }

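The goodput flags above are ordinary entries in the tuning-params dicts, so the recipe scripts touched by this commit can still layer per-run overrides on top of them by assigning into model.tuning_params. A minimal sketch of that pattern, assuming the benchmarks/ directory is on sys.path the way the recipe scripts arrange it; the override values are illustrative, not part of this commit:

import maxtext_trillium_model_configs as model_configs

# One of the model configs the recipes in this commit reference.
model = model_configs.llama3_1_70b_8192_iter_synth_data_and_checkpointing

# Per-run overrides layered on top of the defaults added above (illustrative values).
model.tuning_params["enable_goodput_recording"] = True
model.tuning_params["goodput_upload_interval_seconds"] = 60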
benchmarks/maxtext_xpk_runner.py

Lines changed: 17 additions & 8 deletions

@@ -477,12 +477,19 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
     flags_to_add = wl_config.model.pathways_xla_flag_options[
         xla_flags.ADD_PROXY
     ]
-    proxy_flags.append(flags_to_add)
+    flags_to_add_list = flags_to_add.strip().split()
+    proxy_flags += flags_to_add_list

   # Join the list of flags back into a single string, space-separated
   return ' '.join(proxy_flags)


+def _combine_flag_strings(base_flags: str, flags_to_add: str) -> str:
+  """Combines two flag strings and removes extraneous whitespace."""
+  all_flags = base_flags.split() + flags_to_add.split()
+  return ' '.join(all_flags)
+
+
 def _get_pathways_worker_flags(wl_config: WorkloadConfig):
   """Get the pathways worker flags for the workload and removes any extras."""
   # Add in the xla flags alongside the worker flags from the pathways config.
@@ -499,7 +506,8 @@ def _get_pathways_worker_flags(wl_config: WorkloadConfig):
     flags_to_add = wl_config.model.pathways_xla_flag_options[
         xla_flags.ADD_WORKER
     ]
-    worker_flags += flags_to_add
+
+    worker_flags = _combine_flag_strings(worker_flags, flags_to_add)

   # Join the list of flags back into a single string, space-separated
   return worker_flags
@@ -521,7 +529,7 @@ def _get_pathways_server_flags(wl_config: WorkloadConfig):
     flags_to_add = wl_config.model.pathways_xla_flag_options[
         xla_flags.ADD_SERVER
     ]
-    server_flags += flags_to_add
+    server_flags = _combine_flag_strings(server_flags, flags_to_add)

   # Join the list of flags back into a single string, space-separated
   return server_flags
@@ -581,22 +589,23 @@ def generate_xpk_workload_cmd(
       random.choice(string.ascii_lowercase + string.digits) for _ in range(length_of_random_str)
   )

-  truncate_model_name = 12
-  truncate_prefix = 5
-  common_post_fix = f"-{wl_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}"
+  truncate_model_name = 10
+  truncate_prefix = 3
+  post_fix = f"-{wl_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}"
   common_prefix = os.environ['USER']
   pw_prefix = "pw-"

   if workload_name is None:  # Generate name if not provided
     if is_pathways_enabled:
+      post_fix = f"-{wl_config.num_slices}-{temp_post_fix}"
       name = (
           f"{pw_prefix}{wl_config.model.model_name.replace('_', '-')[:truncate_model_name - len(pw_prefix)]}"
       )
     else:
       name = (
           f"{wl_config.model.model_name.replace('_', '-')[:truncate_model_name]}"
       )
-    name = f"{common_prefix[:truncate_prefix]}-{name}{common_post_fix}"
+    name = f"{common_prefix[:truncate_prefix]}-{name}{post_fix}"
   else:
     name = workload_name  # Use provided name

@@ -629,7 +638,7 @@ def generate_xpk_workload_cmd(
         f'--docker-image={pw_config.runner_image}'
     )
   else:
-    docker_image_flag = f'--base-docker-image="{wl_config.base_docker_image}"'
+    docker_image_flag = f'--docker-image="{wl_config.base_docker_image}"'

   upload_metrics_to_bq_cmd = ""
   if wl_config.generate_metrics_and_upload_to_big_query:
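Before this change, the worker and server paths appended the extra XLA flags with plain string concatenation (worker_flags += flags_to_add), which silently relies on the strings carrying their own separating whitespace; without it, the last base flag and the first added flag fuse into one token. The new _combine_flag_strings helper splits both strings on whitespace and rejoins them with single spaces. A standalone sketch of the difference (the helper body is copied from the diff; the flag strings are made-up placeholders):

def _combine_flag_strings(base_flags: str, flags_to_add: str) -> str:
  """Combines two flag strings and removes extraneous whitespace."""
  all_flags = base_flags.split() + flags_to_add.split()
  return ' '.join(all_flags)

# Placeholder flag strings, purely for illustration.
base = "--flag_a=1"
extra = "--flag_b=2  --flag_c=3"

print(base + extra)                        # '--flag_a=1--flag_b=2  --flag_c=3' (first two flags fused)
print(_combine_flag_strings(base, extra))  # '--flag_a=1 --flag_b=2 --flag_c=3'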

benchmarks/recipes/args_helper.py

Lines changed: 3 additions & 3 deletions

@@ -38,14 +38,14 @@ def _handle_delete(
     **kwargs: Optional keyword arguments, such as xpk_path
   """
   xpk_path = kwargs.get("xpk_path", "xpk")  # Default to "xpk" if not provided
-  first_five_chars = user[:5]
+  first_three_chars = user[:3]
   delete_command = (
       f"python3 {xpk_path}/xpk.py workload delete "
       f"--project={cluster_config.project} --cluster={cluster_config.cluster_name}"
-      f" --filter-by-job={first_five_chars} --zone={cluster_config.zone}"
+      f" --filter-by-job={first_three_chars} --zone={cluster_config.zone}"
   )
   print(
-      f"Deleting workloads starting with: {first_five_chars} using command:"
+      f"Deleting workloads starting with: {first_three_chars} using command:"
       f" {delete_command}"
   )
   os.system(delete_command)
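The five-to-three character change keeps the delete filter aligned with generate_xpk_workload_cmd above, which now prefixes generated workload names with only user[:3] (truncate_prefix = 3). A small sketch of the alignment; the user name is an illustrative placeholder:

user = "sujinesh"                      # illustrative placeholder
truncate_prefix = 3                    # new value in maxtext_xpk_runner.py (was 5)

name_prefix = user[:truncate_prefix]   # "suj": start of every generated workload name
filter_by_job = user[:3]               # "suj": what _handle_delete now passes to xpk

# Both sides agree, so `--filter-by-job=suj` matches every workload this user launched.
assert name_prefix == filter_by_job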
benchmarks/recipes/ (new file)

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import datetime
+import sys
+import os
+
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(parent_dir)
+
+import recipes.args_helper as helper
+import maxtext_trillium_model_configs as model_configs
+import maxtext_xpk_runner as mxr
+from xpk_configs import XpkClusterConfig
+
+# Cluster Params
+CLUSTER = "v6e-256-cluster"
+PROJECT = "tpu-prod-env-cluster"
+ZONE = "us-east5-b"
+REGION = "us-east5"
+COUNTRY = "us"
+DEVICE_TYPE = "v6e-256"
+
+# Other parameters (MUST BE SET BY USER)
+XPK_PATH = os.path.join("~", "xpk")
+USER = os.environ["USER"]
+BASE_OUTPUT_DIRECTORY = (
+    f"gs://{USER}-{PROJECT}-{COUNTRY}/mcjax_long_run/"
+)
+# Generate your own runner image from MaxText repo.
+RUNNER = f"gcr.io/{PROJECT}/{USER}_latest"
+
+MAX_RESTARTS = 10_000
+BENCHMARK_STEPS = 10_000_000
+
+
+def main() -> int:
+  # V6e cluster config
+  cluster_config = XpkClusterConfig(
+      cluster_name=CLUSTER,
+      project=PROJECT,
+      zone=ZONE,
+      device_type=DEVICE_TYPE,
+  )
+
+  # Handle command line arguments using args_helper
+  should_continue = helper.handle_cmd_args(
+      cluster_config, helper.DELETE, xpk_path=XPK_PATH
+  )
+
+  if not should_continue:
+    return 0
+
+  model_list = [
+      # model_configs.llama3_1_70b_8192_pw_lr_real_data,
+      # model_configs.llama3_1_8b_8192,
+      model_configs.llama3_1_70b_8192_iter_synth_data_and_checkpointing,
+      # model_configs.llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds,
+  ]
+  num_slices_list = [
+      2
+  ]
+
+  xpk_workload_cmds = []
+  xpk_workload_names = []
+
+  for model in model_list:
+    # Run workloads on the below clusters
+    for cluster_config in [
+        cluster_config,
+    ]:
+
+      # Make modifications to the model config here to add in any additional
+      # flags or changes to the model config.
+      model.tuning_params["use_vertex_tensorboard"] = True
+      model.tuning_params["vertex_tensorboard_project"] = PROJECT
+      model.tuning_params["vertex_tensorboard_region"] = REGION
+
+      # Run workloads in the following slice configurations
+      for num_slices in num_slices_list:
+        wl_config = mxr.WorkloadConfig(
+            model=model,
+            num_slices=num_slices,
+            device_type=cluster_config.device_type,
+            base_output_directory=BASE_OUTPUT_DIRECTORY,
+            max_restarts=MAX_RESTARTS,
+            libtpu_type=mxr.LibTpuType.MAXTEXT,
+            libtpu_nightly_version="",
+            base_docker_image=RUNNER,
+            xpk_path=XPK_PATH,
+            num_steps=BENCHMARK_STEPS,
+            priority="medium",
+        )
+        command, name = mxr.generate_xpk_workload_cmd(
+            cluster_config=cluster_config, wl_config=wl_config
+        )
+
+        print(f"Name of the workload is: {name} \n")
+        xpk_workload_names.append(name)
+
+        print(f"XPK command to be used is: {command} \n")
+        xpk_workload_cmds.append(command)
+
+  for xpk_workload_name, xpk_workload_cmd in zip(
+      xpk_workload_names, xpk_workload_cmds
+  ):
+    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    print(
+        f"[{timestamp}] Running workload: {xpk_workload_name} with command:"
+        f" {xpk_workload_cmd}"
+    )
+    return_code = mxr.run_command_with_updates(
+        xpk_workload_cmd, xpk_workload_name
+    )
+    if return_code != 0:
+      print(f"Unable to run xpk workload: {xpk_workload_name}")
+
+
+if __name__ == "__main__":
+  main()

benchmarks/recipes/pw_long_running_recipe.py

Lines changed: 6 additions & 3 deletions

@@ -17,11 +17,11 @@
 import datetime
 import sys
 import os
-import args_helper as helper

 parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.append(parent_dir)

+import recipes.args_helper as helper
 import maxtext_trillium_model_configs as model_configs
 import maxtext_xpk_runner as mxr
 from xpk_configs import XpkClusterConfig
@@ -42,7 +42,7 @@
 XPK_PATH = os.path.join("~", "xpk")  # We're running this script from the maxtext directory
 USER = os.environ["USER"]
 BASE_OUTPUT_DIRECTORY = (
-    f"gs://{USER}-{PROJECT}-{COUNTRY}/pw_mcjax_benchmarking/"
+    f"gs://{USER}-{PROJECT}-{COUNTRY}/pw_long_run/"
 )

 MAX_RESTARTS = 10_000
@@ -70,8 +70,10 @@ def main() -> int:
       # model_configs.llama3_1_70b_8192_pw_lr_real_data,
       # model_configs.llama3_1_8b_8192,
       # model_configs.llama3_1_70b_8192_iter_synth_data_and_checkpointing,
-      model_configs.llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds,
+      # model_configs.llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds,
+      model_configs.llama3_1_70b_8192_iter_synthetic,
   ]
+
   pathways_config = mxr.PathwaysConfig(
       server_image=SERVER_IMAGE,
       proxy_server_image=PROXY_IMAGE,
@@ -104,6 +106,7 @@ def main() -> int:
       model.tuning_params["use_vertex_tensorboard"] = True
       model.tuning_params["vertex_tensorboard_project"] = PROJECT
      model.tuning_params["vertex_tensorboard_region"] = REGION
+      model.tuning_params["profiler"] = "xplane"

       # Run workloads in the following slice configurations
       for num_slices in num_slices_list:

benchmarks/upload_metrics_to_bq.py

Lines changed: 1 addition & 1 deletion

@@ -239,7 +239,7 @@ def update_config_with_tuning_params(base_config: omegaconf.DictConfig,

 def main(argv: Sequence[str]) -> None:
   is_pathways = os.environ.get('JAX_PLATFORMS', '') == 'proxy'
-  is_mcjax_0th_worker = int(os.environ['TPU_WORKER_ID']) == 0
+  is_mcjax_0th_worker = int(os.environ.get('TPU_WORKER_ID', -1)) == 0

   # Only write once for McJAX. Pathways is single controller,
   # so only can write once.
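The old form indexed os.environ directly and raised KeyError whenever TPU_WORKER_ID was not set, which presumably happens when this script runs outside a McJAX TPU worker (for instance under Pathways, where only JAX_PLATFORMS='proxy' is checked). With .get('TPU_WORKER_ID', -1) the check simply evaluates to False. A minimal sketch of the behavioral difference, manipulating the environment locally for illustration:

import os

os.environ.pop("TPU_WORKER_ID", None)  # simulate a process where the variable is unset

try:
  old_style = int(os.environ["TPU_WORKER_ID"]) == 0          # old form
except KeyError:
  old_style = "raised KeyError"

new_style = int(os.environ.get("TPU_WORKER_ID", -1)) == 0    # new form: default of -1

print(old_style)  # raised KeyError
print(new_style)  # False: the process is simply not treated as the 0th worker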
