
Conversation


@hshahTT commented Aug 5, 2025

This PR implements the necessary changes to support the Shardy dialect within Torch-XLA (relevant issue: #9348):

  1. Adds support for V2 HLO sharding in the OpSharding and XlaShardingSpec classes (Shardy does not support the V1 shardings that are currently implemented).
  2. Adds the OpenXLA addStablehloImportPipeline() pass, which performs the StableHLO-to-Shardy conversion.
  3. Gates the new behavior behind the "CONVERT_SHLO_TO_SHARDY" environment variable.
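
For anyone trying this out, here is a minimal usage sketch. It assumes the existing torch_xla SPMD API (xs.Mesh / xs.mark_sharding, xr.use_spmd()); the only piece introduced by this PR is the CONVERT_SHLO_TO_SHARDY environment variable, which must be set before tracing:

```python
# Minimal sketch (assumes the standard torch_xla SPMD API; only the
# CONVERT_SHLO_TO_SHARDY variable comes from this PR).
import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"  # enable the StableHLO -> Shardy path

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

t = torch.randn(8, 128).to(xm.xla_device())
# Shard dim 0 across the "data" axis; the emitted module should carry
# V2 (iota) sharding annotations and then be imported into Shardy.
xs.mark_sharding(t, mesh, ("data", None))
```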


hshahTT commented Aug 5, 2025

Note: this PR still needs tests, which I will add in a future commit before merging. I was hoping someone with more knowledge could let me know where to add them, and also take a look at the V2 logic to make sure I haven't made any obvious mistakes.

We (Tenstorrent) have tested this with our own MLIR compiler that ingests Shardy graphs and we saw that the sharding worked as intended for some basic sharding specs. We were also able to run tensor parallel inference on the Llama 3.1 8B model with these changes.

Also, the visualize_tensor_sharding() function is broken when the "CONVERT_SHLO_TO_SHARDY" environment variable is set, since it expects the sharding string to be in V1 format. I will fix that in a future commit, once I know where to add the tests and can make sure I've accounted for all possible sharding specs.


hshahTT commented Aug 14, 2025

Hi, could I get some eyes on this please?

auto transpose_perm_vec = transpose_perm.cast<std::vector<int>>();
std::vector<xla::OpSharding::Type> subgroup_types;
if (dims_vec.size() > transpose_perm.size()) {
subgroup_types.push_back(xla::OpSharding::REPLICATED);


reshape_dims and transpose_perm only decide the device list, which is unrelated to subgroup_types.

Subgroup types should be another input argument of this function.

Author

Moved the subgroup_types logic into xla_sharding.py and added that as an input argument instead.
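
For context, a hedged sketch of what that extra argument could look like on the Python side; the helper name and the integer constant are illustrative, not the exact code in xla_sharding.py:

```python
# Illustrative only: deriving the subgroup_types argument in Python before
# handing it to CreateIotaOpSharding. The constant is assumed to mirror
# xla::OpSharding::REPLICATED.
REPLICATED = 0

def _subgroup_types_for(dims, partition_spec):
    # When the tile-assignment dims carry an extra trailing dimension for the
    # mesh axes the partition spec does not shard over, that trailing
    # dimension is a replicated subgroup (last_tile_dim_replicate).
    return [REPLICATED] if len(dims) > len(partition_spec) else []
```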

@@ -218,6 +218,23 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
return xla::protobuf_util::HaveSameSerialization(a, b);
}

xla::OpSharding ShardingUtil::CreateIotaOpSharding(
const py::list& dims, const py::list& reshape_dims,
const py::list& transpose_perm) {


Add CHECK_EQ(reshape_dims.size(), transpose_perm.size())

Author

Done

@@ -142,6 +152,63 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):
sharding_type = int(sharding_type)
return tile_assignment, group_assignment, replication_groups, sharding_type

@functools.lru_cache(maxsize=None)
def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):


In this function (convertToHloSharding() in OpenXLA), we convert a Shardy sharding (which is based on a mesh and axes) into HloSharding V2. Could you please refer to it?

Author

Thank you for that link! I just updated the logic to match that function.

hshahTT and others added 5 commits August 30, 2025 18:12
…n PyTorch/XLA (#1)

Adds an environment variable CONVERT_SHLO_TO_SHARDY that does two things:

- Uses V2 sharding annotations when generating the GSPMD StableHLO module (e.g., the V1 mesh annotation string devices=[2,1,4]0,1,2,3,4,5,6,7 becomes devices=[2,1,4]<=[8] in V2).
- Converts the resulting GSPMD module with V2 annotations into a Shardy module.
…chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.

- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec.

The new logic now correctly handles cases that were previously unsupported (a sketch of this computation follows the commit list below):

  case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
          -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
          -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 3: mesh_shape=(2,4), partition_spec=(0,None)
          -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <[email protected]>
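
For reference, here is a rough Python sketch that reproduces the three cases in the commit message above. It is an editorial illustration of the dims/reshape_dims/transpose_perm derivation, not the implementation that landed (which was later reworked to follow OpenXLA's convertToHloSharding()):

```python
from math import prod

def iota_tile_args(mesh_shape, partition_spec):
    """Illustrative derivation of (dims, reshape_dims, transpose_perm) for a
    V2 (iota) tile assignment; reproduces the three cases above."""
    # One tile dim per tensor dim: the size of the mesh axis it is sharded
    # over, or 1 if that tensor dim is replicated.
    dims = [mesh_shape[ax] if ax is not None else 1 for ax in partition_spec]
    used = [ax for ax in partition_spec if ax is not None]
    unused = [ax for ax in range(len(mesh_shape)) if ax not in used]
    # Mesh axes not referenced by the partition spec collapse into a trailing
    # replication subgroup (last_tile_dim_replicate) when non-trivial.
    tail = prod(mesh_shape[ax] for ax in unused)
    if tail > 1:
        dims.append(tail)
    return dims, list(mesh_shape), used + unused

assert iota_tile_args((2, 1, 1, 1), (0, None, None, None)) == \
    ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])   # case 1
assert iota_tile_args((2, 1, 1, 1), (0,)) == \
    ([2], [2, 1, 1, 1], [0, 1, 2, 3])            # case 2
assert iota_tile_args((2, 4), (0, None)) == \
    ([2, 1, 4], [2, 4], [0, 1])                  # case 3
```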
@hshahTT force-pushed the hshah/v2-sharding-pr branch from a2fcbf3 to c5186c6 on August 30, 2025 18:16
@hshahTT force-pushed the hshah/v2-sharding-pr branch from c5186c6 to 95be731 on August 30, 2025 18:24

hshahTT commented Aug 30, 2025

Thanks @ZixuanJiang for your review! The link you gave me to the OpenXLA implementation and your comment here were very useful.

Based on your comments I made the following changes:

  • xla_sharding.py:

    • Modified the logic in _get_op_sharding_args_v2() to match the convertToHloSharding() function within OpenXLA
  • test_xla_sharding.py:

    • Modified the existing test cases to test the V2 sharding logic when the CONVERT_SHLO_TO_SHARDY environment variable is set.
    • As currently set up, you must set the environment variable before running the test (i.e., run CONVERT_SHLO_TO_SHARDY=1 python test_xla_sharding.py in whatever Bash script actually calls it in CI).
    • Another way is to parameterize the testing class itself based on whether the env var should be set or not (meaning we run all the tests once with the env variable unset and then again with it set). I can do this but may need to add the parameterized pip module as a test dependency to do it cleanly.
  • debugging.py:

    • Added support for V2 shardings within the visualize_tensor_sharding() debugging function by converting them into V1 shardings first (which are already supported) via the construct_v1_sharding_str() function (a sketch of this V2-to-V1 expansion follows at the end of this comment).
    • I added a _get_xla_op_sharding_v2_params Pybind function inside init_python_bindings.cpp that takes a tensor and returns all the V2 sharding parameters needed to represent it: tile_assignment_dims, reshape_dims, transpose_perm, is_last_tile_dim_replicate.
    • Another way I can get those params is by just parsing the V2 sharding string directly, since it already has the format:
      {devices=[tile_assignment_dims]<=[reshape_dims]T(transpose_perm) last_tile_dim_replicate}
      
      but I thought that the Pybind function would be more readable.
  • test_spmd_debugging.py: Added the ConvertV2ShardingToV1Test testing class to test the construct_v1_sharding_str() function.

I tested everything on a Cloud v4-8 TPU courtesy of Google's TPU Research Cloud program. Please let me know if anything else needs to be done!
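
As a companion to the debugging.py bullet above (the V2-to-V1 expansion mentioned there), here is a hedged sketch of how the V2 parameters expand into the explicit V1 device list. It is illustrative only, not the actual construct_v1_sharding_str() implementation:

```python
import numpy as np

def v2_params_to_v1_str(tile_assignment_dims, reshape_dims, transpose_perm,
                        is_last_tile_dim_replicate):
    """Illustrative only: expand V2 (iota) sharding parameters into the
    explicit V1 device list, e.g. [2,1,4]<=[8] -> [2,1,4]0,1,2,3,4,5,6,7."""
    devices = (np.arange(np.prod(reshape_dims))
                 .reshape(reshape_dims)
                 .transpose(transpose_perm)
                 .reshape(tile_assignment_dims))
    dims = ",".join(str(d) for d in tile_assignment_dims)
    ids = ",".join(str(d) for d in devices.flatten())
    suffix = " last_tile_dim_replicate" if is_last_tile_dim_replicate else ""
    return "{devices=[" + dims + "]" + ids + suffix + "}"

# devices=[2,1,4]<=[8] expands to the explicit device list 0..7
print(v2_params_to_v1_str([2, 1, 4], [8], [0], False))
# {devices=[2,1,4]0,1,2,3,4,5,6,7}
```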

@hshahTT force-pushed the hshah/v2-sharding-pr branch from 7ac5b09 to 3f2af6b on September 2, 2025 21:03
hshahTT added a commit to tenstorrent/pytorch-xla that referenced this pull request Sep 3, 2025
…ation (#7)

This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs.

See pytorch#9541 for the upstream PR discussion and additional context.

* Add some tests and reviewer suggestions. Will update V2 op sharding logic in a later commit soon.

* New implementation (WIP)

* Fix new implementation

* Fix visualize_tensor_sharding function for V2 shardings