Commit 244a071

maxtext authors committed
Merge pull request #1810 from AI-Hypercomputer:weifan-shard
PiperOrigin-RevId: 771220503
2 parents 3296117 + 51be412 commit 244a071

File tree: 5 files changed, +421 -35 lines

MaxText/maxtext_utils.py

Lines changed: 257 additions & 27 deletions
@@ -27,11 +27,14 @@

 import numpy as np

+from collections.abc import Iterable
 from jax.experimental import mesh_utils
 from jax.experimental.serialize_executable import deserialize_and_load
 from jax.sharding import PartitionSpec as P
+
 import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu

 import optax

@@ -343,25 +346,51 @@ def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, co
   return total_tflops, learnable_weight_tflops, causal_attention_tflops


-def assert_params_sufficiently_sharded(params, mesh, tolerance):
-  """Checks whether most params are sharded across sharding axis.
+def get_mesh_axes_used_by_tensor_spec(tensor_sharding_spec):
+  """
+  Extracts the set of mesh axis names that a tensor's PartitionSpec uses.
+
+  This function inspects a tensor's sharding specification (PartitionSpec) and
+  identifies which mesh axes are actively used for sharding. If a tensor is not
+  sharded (i.e., fully replicated), the resulting set will be empty.
+
+  Args:
+    tensor_sharding_spec: The PartitionSpec of a tensor, which defines how it's partitioned across the mesh.
+      It can be None or contain strings and iterables representing the mesh axes.
+    all_mesh_axis_names: A collection of all available mesh axis names in the current device mesh.
+
+  Returns:
+    A set of strings, where each string is a mesh axis name used by the
+    tensor's sharding spec. Returns an empty set for unsharded tensors.
+  """
+  # Flatten the sharding spec, as it can contain nested iterables (e.g., ('data', 'mdl')).
+  tensor_sharding_spec = sum(
+      [
+          [axis] if isinstance(axis, str) else list(axis) if isinstance(axis, Iterable) else []
+          for axis in tensor_sharding_spec
+      ],
+      [],
+  )
+  return tensor_sharding_spec
+
+
+def _get_nontrival_mesh_axes(mesh):
+  """
+  Returns mesh axes from config that are valid and have more than one shard.

-  This function determines whether the majority of parameters are distributed
-  across a specified sharding axes with an acceptable tolerance. It compares the
-  current distribution to a scenario where all parameters are fully sharded
-  across the 'fsdp', 'fsdp_transpose', 'sequence', and 'tensor' axes.
+  This function identifies which of the predefined potential sharding axes are
+  actually present in the current device mesh and are configured with a size
+  greater than one (i.e., are actually sharded).

   Args:
-    params: params of the model state
-    mesh: mesh constructed from config
-    tolerance: float between 0.0 and 1.0 representing the allowed percentage of
-      non-sharded parameters.
+    mesh: The device mesh object, which contains information about the mesh topology, including axis names and their sizes.
+
   Returns:
-    bool: True if the majority of parameters are sufficiently sharded
+    A set of strings, where each string is a mesh axis name that is both
+    pre-configured as a target for sharding and has more than one shard in the mesh.
   """
-  total_num_params = max_utils.calculate_num_params_from_pytree(params)
-  product_num_devices_for_weight_sharding = 1
-  for axis in [
+
+  target_sharding_axes_config = [
       "fsdp",
       "fsdp_transpose",
       "sequence",
@@ -372,19 +401,156 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance):
       "tensor_sequence",
       "stage",
       "expert",
-  ]:
-    product_num_devices_for_weight_sharding *= mesh.shape[axis]
-  total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params)
-  perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
-  assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
-      "Number of parameters per chip must not be less than in the ideal sharded "
-      "scenario across `fsdp`, `fsdp_transpose`, `context`, `sequence`, `tensor`, `tensor_transpose`, "
-      "`tensor_sequence`, `stage`, `expert` axes."
-  )
-  unsharded_param_perc = total_num_params_per_chip / perfectly_sharded_params_per_chip - 1
-  assert unsharded_param_perc < tolerance, (
-      f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% "
-      f"of total parameters with a value of {unsharded_param_perc * 100}%."
+  ]
+
+  # Filter the target axes to find those that exist in the current mesh
+  # and have a size greater than 1, meaning they are actually used for sharding.
+  return {axis for axis in target_sharding_axes_config if axis in mesh.axis_names and mesh.shape[axis] > 1}
+
+
+def _analyze_sharding(params, mesh, valid_target_mesh_axes):
+  """
+  Analyzes parameters to find which are unsharded on any valid mesh axis.
+
+  This function iterates through all parameters in a model, checking their
+  sharding specifications. It identifies parameters that are not sharded along any
+  of the provided valid target axes (i.e., they are fully replicated across these axes).
+
+  Args:
+    params: A PyTree of model parameters.
+    mesh: The device mesh object.
+    valid_target_mesh_axes: A set of mesh axis names that are considered valid targets for sharding.
+
+  Returns:
+    A tuple containing:
+      - unsharded_params_total_size (int): The total size (number of elements) of all parameters found to be
+        unsharded on the target axes.
+      - problematic_tensors_details (list): A list of dictionaries, where each
+        dictionary contains details about a tensor that is not sharded on any of the target axes.
+  """
+  unsharded_params_total_size = 0  # Initialize a counter for the size of unsharded parameters.
+  problematic_tensors_details = []  # Initialize a list to store details of problematic tensors.
+
+  # Get a flattened list of all parameters (leaves) in the PyTree, along with their paths.
+  all_params_leaves = jtu.tree_leaves_with_path(params)
+
+  for path, p_leaf in all_params_leaves:  # Iterate over each parameter leaf
+    param_name_str = jtu.keystr(path)  # Convert the tree path to a readable string
+
+    # Check that sharding and spec exist and are valid
+    sharding = getattr(p_leaf, "sharding", None)
+    spec = getattr(sharding, "spec", None)
+    assert sharding is not None and spec is not None and isinstance(spec, P), (
+        f"Parameter '{param_name_str}' is missing a valid '.sharding.spec'. "
+        "Expected 'p_leaf.sharding.spec' to be a non-null PartitionSpec."
+    )
+
+    current_sharding_spec = p_leaf.sharding.spec  # Extract the current tensor's sharding spec
+    # Identify axes used for sharding
+    mesh_axes_used = get_mesh_axes_used_by_tensor_spec(current_sharding_spec)
+    # Check if the parameter is sharded on all the valid target axes.
+    is_sharded_on_all_target_axis = all(axis in mesh_axes_used for axis in valid_target_mesh_axes)
+
+    # If the parameter is not sharded on all of the target axes, it's considered "problematic."
+    if not is_sharded_on_all_target_axis:
+      unsharded_params_total_size += p_leaf.size  # Add to total unsharded parameter size
+      unsharded_axes = set(valid_target_mesh_axes) - set(mesh_axes_used)
+      # Add detailed info to list of problematic tensors
+      problematic_tensors_details.append(
+          {
+              "name": param_name_str,  # Tensor name
+              "size": p_leaf.size,  # tensor size
+              "shape": p_leaf.shape,  # tensor shape
+              "spec": str(current_sharding_spec),  # Tensor sharding spec as string
+              "available_axes": sorted(list(valid_target_mesh_axes)),  # Axes that could be used for sharding
+              "unsharded_axes": sorted(list(unsharded_axes)),  # Unsharded axes
+          }
+      )
+  # Return the total size of unsharded parameters and the list of problematic tensors.
+  return unsharded_params_total_size, problematic_tensors_details  # Return results
+
+
+def _raise_if_unsharded_exceeds_tolerance(unsharded_size, total_size, tolerance, problematic_tensors_details):
+  """
+  Raises an AssertionError if the percentage of unsharded parameters exceeds the given tolerance.
+
+  This function calculates the proportion of model parameters that are unsharded
+  and compares it against a specified tolerance. If the tolerance is exceeded,
+  it constructs and raises a detailed error message.
+
+  Args:
+    unsharded_size: The total size of parameters not sharded on target axes.
+    total_size: The total size of all parameters in the model.
+    tolerance: A float (e.g., 0.05 for 5%) representing the maximum allowed percentage of unsharded parameters.
+    problematic_tensors_details: A list of details about the unsharded tensors,
+      used to generate an informative error message.
+
+  Raises:
+    AssertionError: If the percentage of unsharded parameters is greater than the tolerance.
+  """
+  if total_size <= 0:
+    raise ValueError("Total size must be greater than zero.")
+
+  # Calculate the percentage of unsharded parameters.
+  unsharded_param_perc = unsharded_size / total_size
+
+  # If the percentage is over the tolerance, prepare and raise an error.
+  if unsharded_param_perc > tolerance:
+    # Sort the problematic tensors by size to show the largest ones first.
+    problematic_tensors_details.sort(key=lambda x: x["size"], reverse=True)

+    # Begin constructing the error message.
+    error_msg_lines = [
+        f"Unsharded parameter percentage ({unsharded_param_perc:.2%}) " f"exceeds tolerance ({tolerance:.2%})."
+    ]
+    # Add a header explaining the issue.
+    error_msg_lines.append(
+        "The following large tensors are replicated (unsharded) but could be sharded on at "
+        "least one of the available axes:"
+    )
+    # Add details for the top 5 largest problematic tensors.
+    for detail in problematic_tensors_details[:5]:  # Show top 5 largest problematic tensors
+      error_msg_lines.append(
+          f" - Name: {detail['name']} (Size: {detail['size']}, Shape: {detail['shape']}, Spec: {detail['spec']}) "
+          f"is unsharded on axes: {detail['unsharded_axes']}"
+          f" could be sharded on: {detail['available_axes']}"
+      )
+
+    # Raise the assertion error with the combined, formatted message.
+    raise AssertionError("\n".join(error_msg_lines))
+
+
+def assert_params_sufficiently_sharded(params, mesh, tolerance):
+  """
+  Asserts that the total size of replicated parameters is within a given tolerance.
+
+  This is the main function that orchestrates the sharding analysis. It determines
+  the total number of parameters, identifies valid sharding axes, analyzes the
+  sharding of all parameters, and then raises an error if the amount of
+  unsharded parameters exceeds the specified tolerance.
+
+  Args:
+    params: A PyTree of model parameters.
+    mesh: The device mesh object.
+    tolerance: A float representing the maximum allowed percentage of unsharded parameters.
+  """
+  # Calculate the total size of all parameters in the model.
+  total_num_params = max_utils.calculate_bytes_from_pytree(params)
+
+  # Get the set of nontrivial mesh axes that can be used for sharding.
+  valid_target_mesh_axes = _get_nontrival_mesh_axes(mesh)
+  # If there are no valid axes to shard along, there's nothing to check, so we can exit.
+  if not valid_target_mesh_axes:
+    return  # Exit early
+
+  # Analyze the parameters to find the total size of unsharded parameters
+  # and get details on which tensors are problematic.
+  unsharded_params_total_size, problematic_tensors_details = _analyze_sharding(params, mesh, valid_target_mesh_axes)
+
+  # Check if the amount of unsharded parameters is within the tolerance and
+  # raise an exception if it is not.
+  _raise_if_unsharded_exceeds_tolerance(
+      unsharded_params_total_size, total_num_params, tolerance, problematic_tensors_details
   )


@@ -848,3 +1014,67 @@ def schedule(step):
     boundaries.append(warmup_steps + cos_steps + constant_zero_steps)

   return optax.join_schedules(pieces, boundaries)
+
+
+def get_formatted_sharding_annotations(params, mesh=None):
+  """
+  Generates a readable string report of sharding annotations for all parameters.
+
+  This function iterates through a PyTree of model parameters and inspects the
+  sharding information attached to each parameter (leaf). It creates a
+  human-readable summary that is useful for debugging sharding configurations.
+
+  Args:
+    params: The PyTree of model parameters to inspect.
+    mesh: (Optional) The device mesh. If provided, its axis names and shape
+      are included in the report for additional context.
+
+  Returns:
+    A single string containing the formatted report of sharding annotations
+    for every parameter, with each entry on a new line.
+  """
+  # Initialize a list to hold the lines of the report, starting with a title.
+  annotation_lines = ["Comprehensive Weight Sharding Annotations:"]
+
+  # If a mesh object is provided, add its details to the report header.
+  if mesh:
+    annotation_lines.append(f"Mesh axes: {mesh.axis_names}, Mesh shape: {mesh.shape}")
+    annotation_lines.append("-" * 30)
+
+  # Get a flattened list of all parameters (leaves) and their corresponding paths in the PyTree.
+  all_params_leaves = jtu.tree_leaves_with_path(params)
+
+  # Loop through each parameter leaf in the flattened list.
+  for path, p_leaf in all_params_leaves:
+    # Convert the parameter's path (a sequence of keys) into a readable string name.
+    param_name_str = jtu.keystr(path)
+    # Get the shape of the parameter as a string.
+    shape_str = str(p_leaf.shape)
+    # Set a default description for sharding, in case none is found.
+    sharding_desc = "N/A"
+
+    # Check if the parameter leaf has a 'sharding' attribute.
+    if hasattr(p_leaf, "sharding"):
+      # Case 1: Standard JAX sharding with a PartitionSpec.
+      if hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is not None:
+        # The spec is a tuple (PartitionSpec), format it for readability.
+        spec_parts = []
+        for item in p_leaf.sharding.spec:
+          # Represent None as "Replicated" to make it explicit.
+          spec_parts.append(str(item) if item is not None else "Replicated")
+        sharding_desc = f"PartitionSpec({', '.join(spec_parts)})"
+      # Case 2: The parameter is explicitly marked as fully replicated.
+      elif hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is None:
+        sharding_desc = "Fully Replicated (spec is None)"
+      # Case 3: A generic fallback if a sharding object exists but has no recognized spec attribute.
+      else:
+        # Print the string representation of the sharding object itself.
+        sharding_desc = str(p_leaf.sharding)
+    # Case 4: The parameter has no .sharding attribute at all.
+    else:
+      sharding_desc = "No .sharding attribute found"
+
+    # Append the formatted details for the current parameter to our list of lines.
+    annotation_lines.append(f" - Param: {param_name_str}\n" f"   Shape: {shape_str}\n" f"   Sharding: {sharding_desc}")
+  # Join all the collected lines into a single string, separated by newlines.
+  return "\n".join(annotation_lines)

MaxText/pyconfig.py

Lines changed: 4 additions & 4 deletions
@@ -829,17 +829,17 @@ def pipeline_first_axis(raw_keys):
   raw_keys = pipeline_first_axis(raw_keys)
   num_stages = int(raw_keys["ici_pipeline_parallelism"] * raw_keys["dcn_pipeline_parallelism"])
   if raw_keys["pipeline_parallel_layers"] == -1:
-    if raw_keys["decoder_block"]=="deepseek":
+    if raw_keys["decoder_block"] == "deepseek":
       moe_layers = raw_keys["num_decoder_layers"] - raw_keys["first_num_dense_layers"]
       raw_keys["pipeline_parallel_layers"] = moe_layers
     else:
       raw_keys["pipeline_parallel_layers"] = raw_keys["num_decoder_layers"]
   else:
-    if raw_keys["decoder_block"]=="deepseek":
+    if raw_keys["decoder_block"] == "deepseek":
       moe_layers = raw_keys["num_decoder_layers"] - raw_keys["first_num_dense_layers"]
       assert (
-        raw_keys["pipeline_parallel_layers"] <= moe_layers
-      ), f"You can only pipeline a subset of the moe decoder layers for deepseek, but you requested to pipeline {raw_keys['pipeline_parallel_layers']} with pipeline_parallel_layers and there are only {moe_layers} decoder layers."
+          raw_keys["pipeline_parallel_layers"] <= moe_layers
+      ), f"You can only pipeline a subset of the moe decoder layers for deepseek, but you requested to pipeline {raw_keys['pipeline_parallel_layers']} with pipeline_parallel_layers and there are only {moe_layers} decoder layers."
     else:
       assert (
           raw_keys["pipeline_parallel_layers"] <= raw_keys["num_decoder_layers"]

MaxText/tests/integration_tests/train_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,5 +334,6 @@ def test_gpu_cudnn_flash_jax(self):
334334
]
335335
train_main(cudnn_flash_jax)
336336

337+
337338
if __name__ == "__main__":
338339
absltest.main()
