Add Shardy support for Torch-XLA #9541
base: master
Conversation
Note: this PR still needs tests, which I will add in a future commit before merging. I was hoping someone with more knowledge of the codebase could tell me where to add them, and also take a look at the V2 logic to make sure I didn't make any obvious mistakes. We (Tenstorrent) have tested this with our own MLIR compiler that ingests Shardy graphs, and the sharding worked as intended for some basic sharding specs. We were also able to run tensor parallel inference on the Llama 3.1 8B model with these changes. Also, the …
Hi, could I get some eyes on this please?
torch_xla/csrc/xla_sharding_util.cpp
Outdated
auto transpose_perm_vec = transpose_perm.cast<std::vector<int>>();
std::vector<xla::OpSharding::Type> subgroup_types;
if (dims_vec.size() > transpose_perm.size()) {
  subgroup_types.push_back(xla::OpSharding::REPLICATED);
reshape_dims and transpose_perm only decide the device list, which is unrelated to subgroup_types. Subgroup types should be another input argument of this function.
Moved the subgroup_types logic into xla_sharding.py and added that as an input argument instead.
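To make the distinction concrete, here is an illustration (a plain dict rather than an actual xla.OpSharding proto, using field names from xla_data.proto's OpSharding message, and not code from this PR) of what a V2 "iota" sharding carries: the iota fields fully determine the device list, while the subgroup types for trailing tile dims are independent information and therefore a separate argument.

```python
# Illustration only -- not code from this PR. Field names follow
# xla_data.proto's OpSharding message; the values correspond to a (2, 4)
# device mesh with the first tensor dim sharded and the second replicated.
v2_sharding = {
    "type": "OTHER",
    # Tile shape: shard dim 0 across 2 devices, keep dim 1 whole, and
    # carry a trailing subgroup of 4 devices.
    "tile_assignment_dimensions": [2, 1, 4],
    # Device list: iota(8) reshaped to [2, 4], then transposed by [0, 1].
    "iota_reshape_dims": [2, 4],
    "iota_transpose_perm": [0, 1],
    # Subgroup type for the trailing dim -- orthogonal to the device list,
    # hence passed in separately.
    "last_tile_dims": ["REPLICATED"],
}
```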
torch_xla/csrc/xla_sharding_util.cpp
Outdated
@@ -218,6 +218,23 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
  return xla::protobuf_util::HaveSameSerialization(a, b);
}

xla::OpSharding ShardingUtil::CreateIotaOpSharding(
    const py::list& dims, const py::list& reshape_dims,
    const py::list& transpose_perm) {
Add CHECK_EQ(reshape_dims.size(), transpose_perm.size())
Done
@@ -142,6 +152,63 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):
    sharding_type = int(sharding_type)
    return tile_assignment, group_assignment, replication_groups, sharding_type

  @functools.lru_cache(maxsize=None)
  def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
In this function, we convert Shardy sharding (which is based on mesh and axes) into HloShardingV2. Could you please refer to it?
Thank you for that link! I just updated the logic to match that function.
Force-pushed from a29fc3f to b108d11
…n PyTorch/XLA (#1)

Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things:
- Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]).
- Converts the new GSPMD module with the V2 annotations into a Shardy module.
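To make the V1-vs-V2 difference above concrete, the following standalone snippet (illustration only, not code from this PR) checks that the V2 iota form devices=[2,1,4]<=[8] enumerates the same devices as the V1 explicit list:

```python
import numpy as np

# V1 spells out the device list: devices=[2,1,4]0,1,2,3,4,5,6,7
v1 = np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(2, 1, 4)

# V2 describes it compactly as an iota: devices=[2,1,4]<=[8], i.e. take
# iota(8) and reshape it into the [2, 1, 4] tile assignment.
v2 = np.arange(8).reshape(2, 1, 4)

assert np.array_equal(v1, v2)
```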
…chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.
- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY; the previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec. The new logic now correctly handles cases that were previously unsupported:
  - case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None) -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
  - case 2: mesh_shape=(2,1,1,1), partition_spec=(0,) -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
  - case 3: mesh_shape=(2,4), partition_spec=(0,None) -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]

* Fix formatting according to Torch-XLA style guide

Co-authored-by: Het Shah <[email protected]>
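The three cases above can be reproduced with a small self-contained sketch. This is a hypothetical helper, not the PR's actual _get_op_sharding_args_v2, showing one way dims, reshape_dims, and transpose_perm could be derived from mesh_shape and partition_spec:

```python
# Hypothetical sketch (not the PR's code): derive V2 OpSharding arguments
# from a mesh shape and a partition spec, reproducing the three cases in
# the commit message above.
def op_sharding_args_v2(mesh_shape, partition_spec):
    # Tile dims: one entry per tensor dim; the mesh-axis size if that dim
    # is sharded over an axis, otherwise 1.
    dims = [1 if axis is None else mesh_shape[axis] for axis in partition_spec]

    # Mesh axes that no tensor dim uses are replicated. If they cover more
    # than one device, they become a trailing (subgroup) dim.
    used = {axis for axis in partition_spec if axis is not None}
    unused = [i for i in range(len(mesh_shape)) if i not in used]
    replicated = 1
    for i in unused:
        replicated *= mesh_shape[i]
    if replicated > 1:
        dims.append(replicated)

    # The iota device list is the mesh itself, transposed so the used axes
    # (in partition-spec order) come before the unused ones.
    reshape_dims = list(mesh_shape)
    transpose_perm = [a for a in partition_spec if a is not None] + unused
    return dims, reshape_dims, transpose_perm


assert op_sharding_args_v2((2, 1, 1, 1), (0, None, None, None)) == \
    ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])
assert op_sharding_args_v2((2, 1, 1, 1), (0,)) == \
    ([2], [2, 1, 1, 1], [0, 1, 2, 3])
assert op_sharding_args_v2((2, 4), (0, None)) == \
    ([2, 1, 4], [2, 4], [0, 1])
```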
…ogic in a later commit soon.
Force-pushed from a2fcbf3 to c5186c6
Force-pushed from c5186c6 to 95be731
Thanks @ZixuanJiang for your review! The link you gave me to the OpenXLA implementation and your comment here were very useful. Based on your comments I made the following changes:
I tested everything on a Cloud v4-8 TPU courtesy of Google's TPU Research Cloud program. Please let me know if anything else needs to be done!
Force-pushed from 7ac5b09 to 3f2af6b
…ation (#7)

This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs. See pytorch#9541 for the upstream PR discussion and additional context.
* Add some tests and reviewer suggestions. Will update V2 op sharding logic in a later commit soon.
* New implementation (WIP)
* Fix new implementation
* Fix visualize_tensor_sharding function for V2 shardings
This PR implements the necessary changes to support the Shardy dialect within Torch-XLA (relevant issue: #9348):
- Adds V2 sharding support to the OpSharding and XlaShardingSpec classes (since Shardy doesn't support the V1 shardings that are currently implemented).
- Adds a call to the addStablehloImportPipeline() pass that performs the SHLO to Shardy conversion.
- Gates the new behavior behind the "CONVERT_SHLO_TO_SHARDY" environment variable.
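For context, here is a minimal, hedged usage sketch of how the new path would be exercised end to end. It assumes the standard torch_xla SPMD helpers (Mesh, mark_sharding, and the visualize_tensor_sharding debugging utility mentioned in the commits above); treat the exact import paths and device setup as illustrative rather than as part of this PR.

```python
import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"   # opt into the Shardy (V2) path

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()
num_devices = xr.global_runtime_device_count()

# 1-D mesh over all devices; a 2-D mesh such as (2, 4) on a v4-8 also works.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

t = torch.zeros(16, 16, device=xm.xla_device())
# Shard rows across the 'data' axis; columns stay replicated. With the
# environment variable set, the generated module carries V2 annotations
# and is converted to the Shardy dialect before compilation.
xs.mark_sharding(t, mesh, ('data', None))
visualize_tensor_sharding(t)  # render the resulting sharding for inspection
```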