import numpy as np

+from collections.abc import Iterable
from jax.experimental import mesh_utils
from jax.experimental.serialize_executable import deserialize_and_load
from jax.sharding import PartitionSpec as P
+
import jax
import jax.numpy as jnp
+import jax.tree_util as jtu

import optax

@@ -343,25 +346,51 @@ def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, co
  return total_tflops, learnable_weight_tflops, causal_attention_tflops


-def assert_params_sufficiently_sharded(params, mesh, tolerance):
-  """Checks whether most params are sharded across sharding axis.
+def get_mesh_axes_used_by_tensor_spec(tensor_sharding_spec):
+  """
+  Extracts the list of mesh axis names that a tensor's PartitionSpec uses.
+
+  This function inspects a tensor's sharding specification (PartitionSpec) and
+  identifies which mesh axes are actively used for sharding. If a tensor is not
+  sharded (i.e., fully replicated), the result is empty.
+
+  Args:
+    tensor_sharding_spec: The PartitionSpec of a tensor, which defines how it is
+      partitioned across the mesh. Its entries may be None, strings, or
+      iterables of strings naming mesh axes.
+
+  Returns:
+    A list of strings, where each string is a mesh axis name used by the
+    tensor's sharding spec. Returns an empty list for unsharded tensors.
+  """
+  # Flatten the sharding spec, as it can contain nested iterables (e.g., ('data', 'mdl')).
+  flattened_axes = sum(
+      [
+          [axis] if isinstance(axis, str) else list(axis) if isinstance(axis, Iterable) else []
+          for axis in tensor_sharding_spec
+      ],
+      [],
+  )
+  return flattened_axes
+
+
+def _get_nontrivial_mesh_axes(mesh):
+  """
+  Returns mesh axes from config that are valid and have more than one shard.

-  This function determines whether the majority of parameters are distributed
-  across a specified sharding axes with an acceptable tolerance. It compares the
-  current distribution to a scenario where all parameters are fully sharded
-  across the 'fsdp', 'fsdp_transpose', 'sequence', and 'tensor' axes.
+  This function identifies which of the predefined potential sharding axes are
+  actually present in the current device mesh and are configured with a size
+  greater than one (i.e., actually shard data).

  Args:
-    params: params of the model state
-    mesh: mesh constructed from config
-    tolerance: float between 0.0 and 1.0 representing the allowed percentage of
-      non-sharded parameters.
+    mesh: The device mesh object, which contains information about the mesh
+      topology, including axis names and their sizes.
+
  Returns:
-    bool: True if the majority of parameters are sufficiently sharded
+    A set of strings, where each string is a mesh axis name that is both
+    pre-configured as a target for sharding and has more than one shard in the mesh.
  """
-  total_num_params = max_utils.calculate_num_params_from_pytree(params)
-  product_num_devices_for_weight_sharding = 1
-  for axis in [
+
+  target_sharding_axes_config = [
      "fsdp",
      "fsdp_transpose",
      "sequence",
@@ -372,19 +401,156 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance):
      "tensor_sequence",
      "stage",
      "expert",
-  ]:
-    product_num_devices_for_weight_sharding *= mesh.shape[axis]
-  total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params)
-  perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
-  assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
-      "Number of parameters per chip must not be less than in the ideal sharded "
-      "scenario across `fsdp`, `fsdp_transpose`, `context`, `sequence`, `tensor`, `tensor_transpose`, "
-      "`tensor_sequence`, `stage`, `expert` axes."
-  )
-  unsharded_param_perc = total_num_params_per_chip / perfectly_sharded_params_per_chip - 1
-  assert unsharded_param_perc < tolerance, (
-      f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% "
-      f"of total parameters with a value of {unsharded_param_perc * 100}%."
+  ]
+
+  # Filter the target axes to find those that exist in the current mesh
+  # and have a size greater than 1, meaning they are actually used for sharding.
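+  # For example (illustrative): with axis sizes {"data": 4, "fsdp": 8, "tensor": 1},
+  # only "fsdp" qualifies; "data" is not a target axis and "tensor" has one shard.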
+  return {axis for axis in target_sharding_axes_config if axis in mesh.axis_names and mesh.shape[axis] > 1}
+
+
+def _analyze_sharding(params, mesh, valid_target_mesh_axes):
+  """
+  Analyzes parameters to find those not sharded along every valid mesh axis.
+
+  This function iterates through all parameters in a model, checking their
+  sharding specifications. It identifies parameters that are not sharded along
+  every one of the provided valid target axes (including parameters that are
+  fully replicated across all of them).
+
+  Args:
+    params: A PyTree of model parameters.
+    mesh: The device mesh object.
+    valid_target_mesh_axes: A set of mesh axis names that are considered valid targets for sharding.
+
+  Returns:
+    A tuple containing:
+      - unsharded_params_total_size (int): The total size (number of elements) of all
+        parameters found to be unsharded on some target axis.
+      - problematic_tensors_details (list): A list of dictionaries, where each
+        dictionary contains details about a tensor that is not sharded on every target axis.
+  """
+  unsharded_params_total_size = 0  # Running total of elements in unsharded parameters.
+  problematic_tensors_details = []  # Details of tensors that fail the sharding check.
+
+  # Get a flattened list of all parameters (leaves) in the PyTree, along with their paths.
+  all_params_leaves = jtu.tree_leaves_with_path(params)
+
+  for path, p_leaf in all_params_leaves:  # Iterate over each parameter leaf
+    param_name_str = jtu.keystr(path)  # Convert the tree path to a readable string
+
+    # Check that sharding and spec exist and are valid
+    sharding = getattr(p_leaf, "sharding", None)
+    spec = getattr(sharding, "spec", None)
+    assert sharding is not None and spec is not None and isinstance(spec, P), (
+        f"Parameter '{param_name_str}' is missing a valid '.sharding.spec'. "
+        "Expected 'p_leaf.sharding.spec' to be a non-null 'PartitionSpec'."
+    )
+
+    current_sharding_spec = p_leaf.sharding.spec  # Extract the current tensor's sharding spec
+    # Identify axes used for sharding
+    mesh_axes_used = get_mesh_axes_used_by_tensor_spec(current_sharding_spec)
+    # Check whether the parameter is sharded on every valid target axis.
+    is_sharded_on_all_target_axes = all(axis in mesh_axes_used for axis in valid_target_mesh_axes)
+
+    # If the parameter is not sharded on all of the target axes, it's considered "problematic".
+    if not is_sharded_on_all_target_axes:
+      unsharded_params_total_size += p_leaf.size  # Add to total unsharded parameter size
+      unsharded_axes = set(valid_target_mesh_axes) - set(mesh_axes_used)
+      # Add detailed info to the list of problematic tensors
+      problematic_tensors_details.append(
+          {
+              "name": param_name_str,  # Tensor name
+              "size": p_leaf.size,  # Tensor size
+              "shape": p_leaf.shape,  # Tensor shape
+              "spec": str(current_sharding_spec),  # Tensor sharding spec as string
+              "available_axes": sorted(list(valid_target_mesh_axes)),  # Axes that could be used for sharding
+              "unsharded_axes": sorted(list(unsharded_axes)),  # Target axes this tensor is not sharded on
+          }
+      )
+  # Return the total size of unsharded parameters and the list of problematic tensors.
+  return unsharded_params_total_size, problematic_tensors_details
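+
+# Illustrative return value (hypothetical tensor): the total unsharded element
+# count plus one detail dict per flagged tensor, e.g.
+#   (4096, [{"name": "['embed']['kernel']", "size": 4096, "shape": (64, 64),
+#            "spec": "PartitionSpec()", "available_axes": ["fsdp"],
+#            "unsharded_axes": ["fsdp"]}])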
+
+
+def _raise_if_unsharded_exceeds_tolerance(unsharded_size, total_size, tolerance, problematic_tensors_details):
+  """
+  Raises an AssertionError if the fraction of unsharded parameters exceeds the given tolerance.
+
+  This function calculates the proportion of model parameters that are unsharded
+  and compares it against a specified tolerance. If the tolerance is exceeded,
+  it constructs and raises a detailed error message.
+
+  Args:
+    unsharded_size: The total size of parameters not sharded on the target axes.
+    total_size: The total size of all parameters in the model.
+    tolerance: A float (e.g., 0.05 for 5%) representing the maximum allowed fraction of unsharded parameters.
+    problematic_tensors_details: A list of details about the unsharded tensors,
+      used to generate an informative error message.
+
+  Raises:
+    AssertionError: If the fraction of unsharded parameters is greater than the tolerance.
+  """
+  if total_size <= 0:
+    raise ValueError("Total size must be greater than zero.")
+
+  # Calculate the fraction of unsharded parameters.
+  unsharded_param_perc = unsharded_size / total_size
+
+  # If the fraction is over the tolerance, prepare and raise an error.
+  if unsharded_param_perc > tolerance:
+    # Sort the problematic tensors by size to show the largest ones first.
+    problematic_tensors_details.sort(key=lambda x: x["size"], reverse=True)
+
+    # Begin constructing the error message.
+    error_msg_lines = [
+        f"Unsharded parameter percentage ({unsharded_param_perc:.2%}) exceeds tolerance ({tolerance:.2%})."
+    ]
+    # Add a header explaining the issue.
+    error_msg_lines.append(
+        "The following large tensors are replicated (unsharded) but could be sharded on at "
+        "least one of the available axes:"
+    )
+    # Add details for the top 5 largest problematic tensors.
+    for detail in problematic_tensors_details[:5]:
+      error_msg_lines.append(
+          f"  - Name: {detail['name']} (Size: {detail['size']}, Shape: {detail['shape']}, Spec: {detail['spec']}) "
+          f"is unsharded on axes: {detail['unsharded_axes']} "
+          f"and could be sharded on: {detail['available_axes']}"
+      )
+
+    # Raise the assertion error with the combined, formatted message.
+    raise AssertionError("\n".join(error_msg_lines))
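+
+# Illustrative arithmetic: with unsharded_size=6e6 and total_size=1e8, the
+# unsharded fraction is 0.06, so tolerance=0.05 raises while tolerance=0.10 passes.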
+
+
+def assert_params_sufficiently_sharded(params, mesh, tolerance):
+  """
+  Asserts that the total size of insufficiently sharded parameters is within a given tolerance.
+
+  This is the main function that orchestrates the sharding analysis. It determines
+  the total number of parameters, identifies valid sharding axes, analyzes the
+  sharding of all parameters, and then raises an error if the amount of
+  unsharded parameters exceeds the specified tolerance.
+
+  Args:
+    params: A PyTree of model parameters.
+    mesh: The device mesh object.
+    tolerance: A float representing the maximum allowed fraction of unsharded parameters.
+  """
+  # Calculate the total number of parameters in the model (element count, to
+  # match the sizes accumulated by _analyze_sharding).
+  total_num_params = max_utils.calculate_num_params_from_pytree(params)
+
+  # Get the set of nontrivial mesh axes that can be used for sharding.
+  valid_target_mesh_axes = _get_nontrivial_mesh_axes(mesh)
+  # If there are no valid axes to shard along, there's nothing to check, so exit early.
+  if not valid_target_mesh_axes:
+    return
+
+  # Analyze the parameters to find the total size of unsharded parameters
+  # and get details on which tensors are problematic.
+  unsharded_params_total_size, problematic_tensors_details = _analyze_sharding(params, mesh, valid_target_mesh_axes)
+
+  # Check if the amount of unsharded parameters is within the tolerance and
+  # raise an exception if it is not.
+  _raise_if_unsharded_exceeds_tolerance(
+      unsharded_params_total_size, total_num_params, tolerance, problematic_tensors_details
  )
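+
+# Example usage (illustrative; `state.params` is a hypothetical caller-side
+# PyTree and `mesh` the configured device mesh):
+#   assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02)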
@@ -848,3 +1014,67 @@ def schedule(step):
  boundaries.append(warmup_steps + cos_steps + constant_zero_steps)

  return optax.join_schedules(pieces, boundaries)
+
+
+def get_formatted_sharding_annotations(params, mesh=None):
+  """
+  Generates a readable string report of sharding annotations for all parameters.
+
+  This function iterates through a PyTree of model parameters and inspects the
+  sharding information attached to each parameter (leaf). It creates a
+  human-readable summary that is useful for debugging sharding configurations.
+
+  Args:
+    params: The PyTree of model parameters to inspect.
+    mesh: (Optional) The device mesh. If provided, its axis names and shape
+      are included in the report for additional context.
+
+  Returns:
+    A single string containing the formatted report of sharding annotations
+    for every parameter, with each entry on a new line.
+  """
+  # Initialize a list to hold the lines of the report, starting with a title.
+  annotation_lines = ["Comprehensive Weight Sharding Annotations:"]
+
+  # If a mesh object is provided, add its details to the report header.
+  if mesh:
+    annotation_lines.append(f"Mesh axes: {mesh.axis_names}, Mesh shape: {mesh.shape}")
+    annotation_lines.append("-" * 30)
+
+  # Get a flattened list of all parameters (leaves) and their corresponding paths in the PyTree.
+  all_params_leaves = jtu.tree_leaves_with_path(params)
+
+  # Loop through each parameter leaf in the flattened list.
+  for path, p_leaf in all_params_leaves:
+    # Convert the parameter's path (a sequence of keys) into a readable string name.
+    param_name_str = jtu.keystr(path)
+    # Get the shape of the parameter as a string.
+    shape_str = str(p_leaf.shape)
+    # Set a default description for sharding, in case none is found.
+    sharding_desc = "N/A"
+
+    # Check if the parameter leaf has a 'sharding' attribute.
+    if hasattr(p_leaf, "sharding"):
+      # Case 1: Standard JAX sharding with a PartitionSpec.
+      if hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is not None:
+        # The spec is a PartitionSpec (a tuple); format it for readability.
+        spec_parts = []
+        for item in p_leaf.sharding.spec:
+          # Represent None as "Replicated" to make it explicit.
+          spec_parts.append(str(item) if item is not None else "Replicated")
+        sharding_desc = f"PartitionSpec({', '.join(spec_parts)})"
+      # Case 2: The parameter is explicitly marked as fully replicated.
+      elif hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is None:
+        sharding_desc = "Fully Replicated (spec is None)"
+      # Case 3: A generic fallback if a sharding object exists but has no recognized spec attribute.
+      else:
+        # Use the string representation of the sharding object itself.
+        sharding_desc = str(p_leaf.sharding)
+    # Case 4: The parameter has no .sharding attribute at all.
+    else:
+      sharding_desc = "No .sharding attribute found"
+
+    # Append the formatted details for the current parameter to the report.
+    annotation_lines.append(f"  - Param: {param_name_str}\n    Shape: {shape_str}\n    Sharding: {sharding_desc}")
+  # Join all the collected lines into a single string, separated by newlines.
+  return "\n".join(annotation_lines)