Skip to content

Commit 4d4b6b0

Browse files
author
maxtext authors
committed
Merge pull request #1708 from AI-Hypercomputer:sujinesh/support_headless_mode_recipe
PiperOrigin-RevId: 771265392
2 parents 4b099d9 + c5faff6 commit 4d4b6b0

File tree

3 files changed

+132
-14
lines changed

3 files changed

+132
-14
lines changed

benchmarks/maxtext_xpk_runner.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class PathwaysConfig:
7878
server_flags: str = ''
7979
proxy_flags: str = ''
8080
worker_flags: str = ''
81+
headless: bool = False
8182

8283

8384
# TODO(@vbarr): Split out parameters related to XPK workload and a General workload
@@ -446,7 +447,7 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
446447

447448
# Get proxy and xla flag string from model config
448449
proxy_flags_string = pw_config.proxy_flags
449-
xla_flags_string = wl_config.model.xla_flags
450+
xla_flags_string = wl_config.model.xla_flags if not pw_config.headless else ''
450451

451452
# Split both proxy_flags_string and xla_flags_string into lists of flags
452453
proxy_flags_list = proxy_flags_string.strip().split()
@@ -457,8 +458,8 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
457458

458459
# Remove the flags that are specified to be removed.
459460
if (
460-
wl_config.model.pathways_xla_flag_options
461-
and xla_flags.REMOVE in wl_config.model.pathways_xla_flag_options
461+
not pw_config.headless and (wl_config.model.pathways_xla_flag_options
462+
and xla_flags.REMOVE in wl_config.model.pathways_xla_flag_options)
462463
):
463464
flags_to_remove = wl_config.model.pathways_xla_flag_options[
464465
xla_flags.REMOVE
@@ -471,8 +472,8 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
471472

472473
# Add the flags that are specified to be added.
473474
if (
474-
wl_config.model.pathways_xla_flag_options
475-
and xla_flags.ADD_PROXY in wl_config.model.pathways_xla_flag_options
475+
not pw_config.headless and (wl_config.model.pathways_xla_flag_options
476+
and xla_flags.ADD_PROXY in wl_config.model.pathways_xla_flag_options)
476477
):
477478
flags_to_add = wl_config.model.pathways_xla_flag_options[
478479
xla_flags.ADD_PROXY
@@ -500,8 +501,8 @@ def _get_pathways_worker_flags(wl_config: WorkloadConfig):
500501

501502
# Add the flags that are specified to be added.
502503
if (
503-
wl_config.model.pathways_xla_flag_options
504-
and xla_flags.ADD_WORKER in wl_config.model.pathways_xla_flag_options
504+
not pw_config.headless and (wl_config.model.pathways_xla_flag_options
505+
and xla_flags.ADD_WORKER in wl_config.model.pathways_xla_flag_options)
505506
):
506507
flags_to_add = wl_config.model.pathways_xla_flag_options[
507508
xla_flags.ADD_WORKER
@@ -523,8 +524,8 @@ def _get_pathways_server_flags(wl_config: WorkloadConfig):
523524

524525
# Add the flags that are specified to be added.
525526
if (
526-
wl_config.model.pathways_xla_flag_options
527-
and xla_flags.ADD_SERVER in wl_config.model.pathways_xla_flag_options
527+
not pw_config.headless and (wl_config.model.pathways_xla_flag_options
528+
and xla_flags.ADD_SERVER in wl_config.model.pathways_xla_flag_options)
528529
):
529530
flags_to_add = wl_config.model.pathways_xla_flag_options[
530531
xla_flags.ADD_SERVER
@@ -569,6 +570,7 @@ def _get_pathways_specific_flags(wl_config: WorkloadConfig):
569570
f' --custom-pathways-server-args="{server_flags}" '
570571
f' --custom-pathways-proxy-server-args="{proxy_flags}" '
571572
f' --custom-pathways-worker-args="{worker_flags}" '
573+
f' {"--headless" if pw_config.headless else ""}'
572574
)
573575
return pathways_specific_flags
574576

@@ -582,6 +584,7 @@ def generate_xpk_workload_cmd(
582584
"""Generates a command to run a maxtext model on XPK."""
583585

584586
is_pathways_enabled = wl_config.pathways_config is not None
587+
is_pathways_headless_enabled = wl_config.pathways_config and wl_config.pathways_config.headless
585588

586589
time.localtime()
587590
length_of_random_str = 3
@@ -614,10 +617,12 @@ def generate_xpk_workload_cmd(
614617
wl_config.run_name,
615618
'metrics')
616619

617-
user_command = build_user_command(
618-
name=name,
619-
wl_config=wl_config
620-
)
620+
user_command = ''
621+
if not is_pathways_headless_enabled:
622+
user_command = build_user_command(
623+
name=name,
624+
wl_config=wl_config
625+
)
621626

622627
additional_flags = ''
623628
if not is_pathways_enabled and wl_config.libtpu_type == LibTpuType.CUSTOM:
@@ -641,7 +646,7 @@ def generate_xpk_workload_cmd(
641646
docker_image_flag = f'--docker-image="{wl_config.base_docker_image}"'
642647

643648
upload_metrics_to_bq_cmd = ""
644-
if wl_config.generate_metrics_and_upload_to_big_query:
649+
if wl_config.generate_metrics_and_upload_to_big_query and not is_pathways_headless_enabled:
645650
# TODO (optionally) make it so that this upload step is done on local device instead of within the workload.
646651
args = _build_args_from_config(wl_config)
647652
args_str = ""
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Copyright 2025 Google LLC
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
import os
17+
import sys
18+
19+
import benchmarks.recipes.args_helper as helper
20+
import maxtext_xpk_runner as mxr
21+
from recipes.user_configs import cluster_config, xpk_path, pathways_config, base_output_directory, headless_workload_name
22+
23+
24+
def main() -> int:
  """Build and launch the headless XPK workload described by the user configs.

  Command-line arguments are processed first through
  ``helper.handle_cmd_args`` (invoked with ``helper.DELETE``); if it reports
  that execution should not continue, nothing is launched.

  Returns:
    0 when ``helper.handle_cmd_args`` says not to continue; otherwise the
    return code of ``mxr.run_command_with_updates`` for the generated command.
  """
  # Handle command line arguments using args_helper
  should_continue = helper.handle_cmd_args(
      cluster_config, helper.DELETE, xpk_path=xpk_path
  )

  if not should_continue:
    return 0

  num_slices = 2

  # Run workloads in the following slice configurations
  # NOTE(review): model, libtpu_type and the docker image are deliberately
  # left unset here — presumably because a headless workload runs no user
  # command inside the cluster; confirm against maxtext_xpk_runner.
  wl_config = mxr.WorkloadConfig(
      model=None,
      num_slices=num_slices,
      device_type=cluster_config.device_type,
      base_output_directory=base_output_directory,
      max_restarts=0,
      libtpu_type=None,
      libtpu_nightly_version="",
      base_docker_image="",
      pathways_config=pathways_config,
      xpk_path=xpk_path,
  )
  command, name = mxr.generate_xpk_workload_cmd(
      cluster_config=cluster_config,
      wl_config=wl_config,
      workload_name=headless_workload_name,
  )

  print(f"Name of the workload is: {name} \n")
  print(f"XPK command to be used is: {command} \n")

  return_code = mxr.run_command_with_updates(command, name)
  if return_code != 0:
    print(f"Unable to run xpk workload: {name}")

  # Hand the launcher's status back so the declared ``-> int`` contract holds
  # on every path and callers can detect a failed launch.
  return return_code


if __name__ == "__main__":
  # Propagate the workload launch status as the process exit code.
  sys.exit(main())

benchmarks/recipes/user_configs.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Copyright 2025 Google LLC
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
"""Define user specific configurations for recipes here."""
17+
18+
import os
19+
20+
import maxtext_xpk_runner as mxr
21+
from xpk_configs import XpkClusterConfig
22+
23+
# Cluster that the recipes submit their XPK workloads to.
cluster_config = XpkClusterConfig(
    cluster_name="test-v5e-32-cluster",
    project="cloud-tpu-cluster",
    zone="us-south1-a",
    device_type="v5litepod-32",
)
xpk_path = "~/xpk"

# Raises KeyError when the USER environment variable is not set.
user = os.environ["USER"]
# The zone with its last "-"-separated component dropped.
region = "-".join(cluster_config.zone.split("-")[:-1])

# Container images, all namespaced by the current user.
proxy_image = f"us-docker.pkg.dev/path/to/{user}/proxy_server"
server_image = f"us-docker.pkg.dev/path/to/{user}/server"
colocated_python_image = (
    f"gcr.io/{cluster_config.project}/path/to/{user}/colocated_python_sidecar"
)
runner = f"gcr.io/{cluster_config.project}/{user}_maxtext_latest:latest"

base_output_directory = f"gs://{user}-{region}/{user}"

# Pathways settings shared by the recipes; headless mode is switched on here.
headless = True
pathways_config = mxr.PathwaysConfig(
    server_image=server_image,
    proxy_server_image=proxy_image,
    runner_image=runner,
    colocated_python_sidecar_image=colocated_python_image,
    headless=headless,
)

# Workload name prefix: the first three characters of the user name.
headless_workload_name = f"{user[:3]}-headless"

0 commit comments

Comments
 (0)