@@ -78,6 +78,7 @@ class PathwaysConfig:
78
78
server_flags : str = ''
79
79
proxy_flags : str = ''
80
80
worker_flags : str = ''
81
+ headless : bool = False
81
82
82
83
83
84
# TODO(@vbarr): Split out parameters related to XPK workload and a General workload
@@ -446,7 +447,7 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
446
447
447
448
# Get proxy and xla flag string from model config
448
449
proxy_flags_string = pw_config .proxy_flags
449
- xla_flags_string = wl_config .model .xla_flags
450
+ xla_flags_string = wl_config .model .xla_flags if not pw_config . headless else ''
450
451
451
452
# Split both proxy_flags_string and xla_flags_string into lists of flags
452
453
proxy_flags_list = proxy_flags_string .strip ().split ()
@@ -457,8 +458,8 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
457
458
458
459
# Remove the flags that are specified to be removed.
459
460
if (
460
- wl_config .model .pathways_xla_flag_options
461
- and xla_flags .REMOVE in wl_config .model .pathways_xla_flag_options
461
+ not pw_config . headless and ( wl_config .model .pathways_xla_flag_options
462
+ and xla_flags .REMOVE in wl_config .model .pathways_xla_flag_options )
462
463
):
463
464
flags_to_remove = wl_config .model .pathways_xla_flag_options [
464
465
xla_flags .REMOVE
@@ -471,8 +472,8 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
471
472
472
473
# Add the flags that are specified to be added.
473
474
if (
474
- wl_config .model .pathways_xla_flag_options
475
- and xla_flags .ADD_PROXY in wl_config .model .pathways_xla_flag_options
475
+ not pw_config . headless and ( wl_config .model .pathways_xla_flag_options
476
+ and xla_flags .ADD_PROXY in wl_config .model .pathways_xla_flag_options )
476
477
):
477
478
flags_to_add = wl_config .model .pathways_xla_flag_options [
478
479
xla_flags .ADD_PROXY
@@ -500,8 +501,8 @@ def _get_pathways_worker_flags(wl_config: WorkloadConfig):
500
501
501
502
# Add the flags that are specified to be added.
502
503
if (
503
- wl_config .model .pathways_xla_flag_options
504
- and xla_flags .ADD_WORKER in wl_config .model .pathways_xla_flag_options
504
+ not pw_config . headless and ( wl_config .model .pathways_xla_flag_options
505
+ and xla_flags .ADD_WORKER in wl_config .model .pathways_xla_flag_options )
505
506
):
506
507
flags_to_add = wl_config .model .pathways_xla_flag_options [
507
508
xla_flags .ADD_WORKER
@@ -523,8 +524,8 @@ def _get_pathways_server_flags(wl_config: WorkloadConfig):
523
524
524
525
# Add the flags that are specified to be added.
525
526
if (
526
- wl_config .model .pathways_xla_flag_options
527
- and xla_flags .ADD_SERVER in wl_config .model .pathways_xla_flag_options
527
+ not pw_config . headless and ( wl_config .model .pathways_xla_flag_options
528
+ and xla_flags .ADD_SERVER in wl_config .model .pathways_xla_flag_options )
528
529
):
529
530
flags_to_add = wl_config .model .pathways_xla_flag_options [
530
531
xla_flags .ADD_SERVER
@@ -569,6 +570,7 @@ def _get_pathways_specific_flags(wl_config: WorkloadConfig):
569
570
f' --custom-pathways-server-args="{ server_flags } " '
570
571
f' --custom-pathways-proxy-server-args="{ proxy_flags } " '
571
572
f' --custom-pathways-worker-args="{ worker_flags } " '
573
+ f' { "--headless" if pw_config .headless else "" } '
572
574
)
573
575
return pathways_specific_flags
574
576
@@ -582,6 +584,7 @@ def generate_xpk_workload_cmd(
582
584
"""Generates a command to run a maxtext model on XPK."""
583
585
584
586
is_pathways_enabled = wl_config .pathways_config is not None
587
+ is_pathways_headless_enabled = wl_config .pathways_config and wl_config .pathways_config .headless
585
588
586
589
time .localtime ()
587
590
length_of_random_str = 3
@@ -614,10 +617,12 @@ def generate_xpk_workload_cmd(
614
617
wl_config .run_name ,
615
618
'metrics' )
616
619
617
- user_command = build_user_command (
618
- name = name ,
619
- wl_config = wl_config
620
- )
620
+ user_command = ''
621
+ if not is_pathways_headless_enabled :
622
+ user_command = build_user_command (
623
+ name = name ,
624
+ wl_config = wl_config
625
+ )
621
626
622
627
additional_flags = ''
623
628
if not is_pathways_enabled and wl_config .libtpu_type == LibTpuType .CUSTOM :
@@ -641,7 +646,7 @@ def generate_xpk_workload_cmd(
641
646
docker_image_flag = f'--docker-image="{ wl_config .base_docker_image } "'
642
647
643
648
upload_metrics_to_bq_cmd = ""
644
- if wl_config .generate_metrics_and_upload_to_big_query :
649
+ if wl_config .generate_metrics_and_upload_to_big_query and not is_pathways_headless_enabled :
645
650
# TODO (optionally) make it so that this upload step is done on local device instead of within the workload.
646
651
args = _build_args_from_config (wl_config )
647
652
args_str = ""
0 commit comments