diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 34221d375e9..5ffa74f9886 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -17,6 +17,7 @@ import torch_xla.distributed.spmd as xs
 from torch_xla.distributed.spmd import XLAShardedTensor
 from torch_xla.distributed.spmd import Mesh
+from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str
 
 import test_xla_sharding_base
 
@@ -828,6 +829,77 @@ def test_multi_host_replicated_cpu(self):
       fake_output = fake_capture.get()
       assert output == fake_output
 
+
+class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
+
+  def run_test(self):
+    mesh = self._get_mesh(self.device_mesh_shape)
+    t = torch.randn(self.tensor_shape).to(torch_xla.device())
+    xs.mark_sharding(t, mesh, self.partition_spec)
+    actual_str = construct_v1_sharding_str(t)
+    self.assertEqual(self.expected_str, actual_str)
+
+  def test_tiled_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (1, 128)
+    self.partition_spec = (0, 1)
+    self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
+        [str(i) for i in range(self.n_devices)]))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 2,
+                   f"Requires at least 2 devices.")
+  def test_tupled_tiled_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (16,)
+    self.partition_spec = ((0, 1),)
+    self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_replicated_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (None, None)
+    self.expected_str = '{replicated}'
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_partial_replication_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (0, None)
+    self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_tupled_partial_replication_sharding(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = ((0, 1), None)
+    self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_tupled_partial_replication_sharding_with_transpose(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = (None, (2, 1))
+    device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
+        (2, 1, 0)).flatten()
+    self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    self.run_test()
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
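The ConvertV2ShardingToV1Test cases above check that construct_v1_sharding_str (defined in torch_xla/distributed/spmd/debugging.py further down in this diff) expands a V2 "iota" sharding annotation into the explicit V1 device list. As a rough sketch of that expansion (illustrative only; the helper name expand_iota_to_v1 and the sample values are hypothetical, not part of the patch):

import numpy as np

def expand_iota_to_v1(tile_dims, reshape_dims, transpose_perm):
  # A V2 annotation encodes the device order as iota(prod(reshape_dims)),
  # reshaped to reshape_dims and permuted by transpose_perm.
  num_devices = np.prod(reshape_dims)
  devices = np.arange(num_devices).reshape(reshape_dims).transpose(
      transpose_perm).reshape(num_devices)
  return "{devices=[%s]%s}" % (",".join(str(d) for d in tile_dims),
                               ",".join(str(d) for d in devices))

print(expand_iota_to_v1([2, 2], [4], [0]))        # {devices=[2,2]0,1,2,3}
print(expand_iota_to_v1([1, 4], [2, 2], [1, 0]))  # {devices=[1,4]0,2,1,3}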
diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 48b760f6e3f..b74709d66f6 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
+    cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")
 
   def test_xla_sharded_tensor(self):
     partition_spec = (0, 1)
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in reversed(range(self.n_devices))]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
 
   def test_mark_sharding_2d(self):
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))
 
     actual = (xt1 + xt2).cpu()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
       annotation = '{devices=[1,1,%d,%d]%s}' % (
           z_dim, self.n_devices // z_dim, ','.join(
               [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
+                                                      z_dim, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
 
     actual = (xt + xt).cpu()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16).to('xla')
     xs.mark_sharding(t, mesh, ((0, 1),))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
 
   @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                        "Multiple devices required for tupled partition spec")
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
     # Shard the first dimension on `r` and `b`, replicate the second dimension
     t = torch.randn(16, 16).to('xla')
     xs.mark_sharding(t, mesh, (('r', 'b'), None))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t),
-        "{devices=[2,1,%d]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
 
     # Replicate the first dimension, shard the second on `b` and `m`
     u = torch.randn(16, 16).to('xla')
     xs.mark_sharding(u, mesh, (None, ('b', 'm')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)
 
     # Replicate the first dimension, shard the second on `r` and `m`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('r', 'm')))
     device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v),
-        "{devices=[1,%d,2]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)
 
     # Replicate the first dimension, shard the second on `m` and `b`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('m', 'b')))
     device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
+                                                       self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'Multiple devices required for tupled partition spec')
@@ -452,9 +471,12 @@ def test_multiple_tuples_in_spec(self):
                    ('a', 'b', 'c', 'd'))
     t = torch.randn(2, 2).to('xla')
     xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
+                                               self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'At least 2 devices needed for 2D mesh')
@@ -462,9 +484,12 @@ def test_3d_tensor_2d_mesh(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16, 16, 16).to('xla')
     xs.mark_sharding(t, mesh, (None, 0, 1))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
+                                                 self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
 
   def test_partial_replication_addmm(self):
     device = torch_xla.device()
@@ -983,18 +1008,20 @@ def test_op_sharding_cache(self):
     t = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(t, mesh, (0, 1))
-    self.assertIn("CreateOpSharding", met.counter_names())
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
+    self.assertIn(counter_name, met.counter_names())
+    self.assertEqual(met.counter_value(counter_name), 1)
 
     # Sharding with the same partition spec should not result in another call
     u = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(u, mesh, (0, 1))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    self.assertEqual(met.counter_value(counter_name), 1)
 
-    # Changing the partition spec will result in another CreateOpSharding
+    # Changing the partition spec will result in another
+    # CreateOpSharding or CreateIotaOpSharding call
     v = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(v, mesh, (0, None))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+    self.assertEqual(met.counter_value(counter_name), 2)
 
   def test_from_cpu_shards_replicated(self):
     from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
@@ -1397,10 +1424,10 @@ def test_data_loader_with_sharding(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)
 
   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
@@ -1420,10 +1447,10 @@ def test_data_loader_with_non_batch_size(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)
 
   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
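The test changes above switch the expected annotation string on cls.convert_to_shardy. For a concrete picture, the same mark_sharding call is expected to yield either the V1 or the V2 form of the spec depending on whether CONVERT_SHLO_TO_SHARDY is set; a rough usage sketch, assuming a 4-device runtime (values shown are illustrative, not asserted by this patch):

import torch
import torch_xla
import torch_xla.distributed.spmd as xs

mesh = xs.Mesh(range(4), (1, 4))
t = torch.randn(8, 128).to('xla')
xs.mark_sharding(t, mesh, (0, 1))
spec = torch_xla._XLAC._get_xla_sharding_spec(t)
# Without CONVERT_SHLO_TO_SHARDY: '{devices=[1,4]0,1,2,3}'
# With CONVERT_SHLO_TO_SHARDY=1:  '{devices=[1,4]<=[4]}'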
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 9ce45e8761a..7d0882fbe59 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -752,6 +752,16 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
   return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode);
 }
 
+std::optional<xla::OpSharding> GetXLAOpSharding(const at::Tensor& input) {
+  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  XLATensor::ShardingSpecPtr sharding_spec =
+      xtensor ? xtensor->sharding_spec() : nullptr;
+  if (sharding_spec != nullptr) {
+    return sharding_spec->sharding;
+  }
+  return std::nullopt;
+}
+
 std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
   auto sharding_spec = xtensor->sharding_spec();
   if (sharding_spec != nullptr) {
@@ -1517,6 +1527,10 @@ at::Tensor tensor_fromDLPack(PyObject* data) {
 void InitXlaModuleBindings(py::module m) {
   PythonScope module(m);
 
+  using TileAssignmentDims = std::vector<int64_t>;
+  using ReshapeDims = std::vector<int64_t>;
+  using TransposePerm = std::vector<int32_t>;
+
   // Define the _XLAC.XlaShardingSpec class.
   PythonScope<py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>>(
       m, "XlaShardingSpec")
@@ -1525,24 +1539,22 @@ void InitXlaModuleBindings(py::module m) {
       .def_init([](at::Tensor tensor, const py::list& tile_assignment,
                    const py::list& group_assignment,
                    const py::list& replication_groups, int sharding_type,
                    bool minibatch) {
        xla::Shape global_shape =
-            CreateComputationShapeFromTensor(tensor, nullptr);
-        if (minibatch) {
-          int num_local_devices =
-              runtime::GetComputationClientOrDie()->GetLocalDevices().size();
-          int num_global_devices =
-              runtime::GetComputationClientOrDie()->GetAllDevices().size();
-          XLA_CHECK(tile_assignment.size() == num_global_devices)
-              << "Minibatch sharding only supports sharding along the batch "
-                 "dimension";
-          int batch_dim_shape =
-              tensor.sizes()[0] * num_global_devices / num_local_devices;
-          global_shape.set_dimensions(0, batch_dim_shape);
-        }
+            ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
        return std::make_shared<XLATensor::ShardingSpec>(
            ShardingUtil::CreateOpSharding(
                tile_assignment, group_assignment, replication_groups,
                ShardingUtil::ShardingType(sharding_type)),
            global_shape, minibatch);
+      })
+      .def_init([](at::Tensor tensor, const py::list& dims,
+                   const py::list& reshape_dims, const py::list& transpose_perm,
+                   const py::list& types, bool minibatch) {
+        xla::Shape global_shape =
+            ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
+        return std::make_shared<XLATensor::ShardingSpec>(
+            ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
+                                               transpose_perm, types),
+            global_shape, minibatch);
       });
 
   // Define the _XLAC.IrValue class.
@@ -1561,12 +1573,19 @@ void InitXlaModuleBindings(py::module m) {
 
   // Define the _XLAC.OpSharding class.
   PythonScope<py::class_<xla::OpSharding>>(m, "OpSharding")
+      // Constructor for V1 shardings.
      .def_init([](const py::list& tile_assignment,
                   const py::list& group_assignment,
                   const py::list& replication_groups, int sharding_type) {
        return ShardingUtil::CreateOpSharding(
            tile_assignment, group_assignment, replication_groups,
            ShardingUtil::ShardingType(sharding_type));
+      })
+      // Constructor for V2 shardings.
+      .def_init([](const py::list& dims, const py::list& reshape_dims,
+                   const py::list& transpose_perm, const py::list& types) {
+        return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
+                                                  transpose_perm, types);
       });
 
   // Define the _XLAC.PjRtPlugin class.
@@ -1792,7 +1811,8 @@ void InitXlaModuleBindings(py::module m) {
           }
         })
       .def("_xla_get_runtime_devices",
-           []() { return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
+           []() {
+             return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
      .def("_xla_num_runtime_devices",
           []() -> int64_t {
             return runtime::GetComputationClientOrDie()->GetNumLocalDevices();
@@ -2212,9 +2232,11 @@ void InitXlaModuleBindings(py::module m) {
             return device.ordinal();
           })
      .def("_xla_get_process_index",
-           []() { return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
+           []() {
+             return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
      .def("_xla_get_num_processes",
-           []() { return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
+           []() {
+             return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
      .def("_xla_get_num_cached_compilation_graph",
           []() -> int64_t {
             return XLAGraphExecutor::Get()->GetNumGraphHash();
@@ -2646,13 +2668,26 @@ void InitXlaModuleBindings(py::module m) {
         })
      .def("_get_xla_op_sharding",
           [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
-             XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input));
-             XLATensor::ShardingSpecPtr sharding_spec =
-                 xtensor ? xtensor->sharding_spec() : nullptr;
-             if (sharding_spec != nullptr) {
-               return sharding_spec->sharding;
+             return GetXLAOpSharding(input);
+           })
+      .def("_get_xla_op_sharding_v2_params",
+           [](const at::Tensor& input)
+               -> std::optional<std::tuple<TileAssignmentDims, ReshapeDims,
+                                           TransposePerm, bool>> {
+             std::optional<xla::OpSharding> maybe_sharding =
+                 GetXLAOpSharding(input);
+             if (!maybe_sharding) {
+               return std::nullopt;
             }
-             return std::nullopt;
+             const xla::OpSharding& sharding = maybe_sharding.value();
+             TileAssignmentDims tile_assignment_dims(
+                 sharding.tile_assignment_dimensions().begin(),
+                 sharding.tile_assignment_dimensions().end());
+             ReshapeDims reshape_dims(sharding.iota_reshape_dims().begin(),
+                                      sharding.iota_reshape_dims().end());
+             TransposePerm transpose_perm(sharding.iota_transpose_perm().begin(),
+                                          sharding.iota_transpose_perm().end());
+             return std::make_tuple(tile_assignment_dims, reshape_dims,
+                                    transpose_perm,
+                                    sharding.replicate_on_last_tile_dim());
           })
      .def("_get_xla_sharding_specs",
           [](const std::vector<at::Tensor>& tensors)
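The new _get_xla_op_sharding_v2_params binding above returns the raw V2 parameters of a tensor's sharding (tile assignment dimensions, iota reshape dims, iota transpose perm, and the replicate-on-last-tile-dim flag), or None for an unsharded tensor. A rough usage sketch from Python (the values shown assume the 4-device example earlier and are illustrative):

params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t)
if params is not None:
  tile_dims, reshape_dims, transpose_perm, replicate_last = params
  # e.g. tile_dims == [1, 4], reshape_dims == [4],
  #      transpose_perm == [0], replicate_last == False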
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index b381d3feff7..0988ed79eaa 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -366,6 +366,7 @@ cc_library(
         "@xla//xla/mlir_hlo:all_passes",
         "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
         "@xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
+        "@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import",
     ],
 )
 
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
index 280b50964d8..d0b552613d1 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -14,6 +14,7 @@
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/pjrt_registry.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
+#include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/util.h"
@@ -638,6 +639,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     mlir::ModuleOp mlir_module =
         mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
     ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module);
+    if (runtime::sys_util::GetEnvBool("CONVERT_SHLO_TO_SHARDY", false)) {
+      ConvertStableHloToSdy(&mlir_module);
+    }
     executable = util::RaisePythonValueErrorOnFailure([&] {
       return fake_xla_compile_ ? fake_xla_compile_()
diff --git a/torch_xla/csrc/runtime/stablehlo_helper.cpp b/torch_xla/csrc/runtime/stablehlo_helper.cpp
index 08856778fd8..857ec580917 100644
--- a/torch_xla/csrc/runtime/stablehlo_helper.cpp
+++ b/torch_xla/csrc/runtime/stablehlo_helper.cpp
@@ -18,6 +18,7 @@
 #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
 #include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
 #include "xla/mlir_hlo/mhlo/transforms/passes.h"
+#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"
 
 namespace torch_xla {
 
@@ -89,6 +90,7 @@ static absl::Status mhloToStablehloHelper(mlir::ModuleOp* mlir_module,
       torch_xla::runtime::CreateRemoveXlaMarkTensorOpsPass());
   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
+
   if (!mlir::succeeded(pm.run(*mlir_module))) {
     return absl::Status(
         absl::StatusCode::kInternal,
@@ -111,6 +113,14 @@ void ConvertHloToStableHlo(const xla::HloModuleProto* proto,
                            << getHloModuleStr(proto);
 }
 
+void ConvertStableHloToSdy(mlir::ModuleOp* mlir_module) {
+  mlir::PassManager pm(mlir_module->getContext());
+  xla::sdy::addStablehloImportPipeline(pm, false, false);
+  if (!mlir::succeeded(pm.run(*mlir_module))) {
+    XLA_ERROR() << "StableHLO -> SDY conversion failed.\n";
+  }
+}
+
 std::string hloToStablehlo(const xla::HloModuleProto* proto,
                            bool emit_bytecode) {
   mlir::MLIRContext context;
diff --git a/torch_xla/csrc/runtime/stablehlo_helper.h b/torch_xla/csrc/runtime/stablehlo_helper.h
index bdef7b97540..2298ecfb2d1 100644
--- a/torch_xla/csrc/runtime/stablehlo_helper.h
+++ b/torch_xla/csrc/runtime/stablehlo_helper.h
@@ -13,6 +13,8 @@ namespace torch_xla {
 std::string hloToStablehlo(const xla::HloModuleProto* proto,
                            bool emit_bytecode);
 
+void ConvertStableHloToSdy(mlir::ModuleOp* mlir_module);
+
 void ConvertHloToStableHlo(const xla::HloModuleProto* proto,
                            mlir::ModuleOp* mlir_module);
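The runtime changes above only run the StableHLO -> SDY import pipeline when CONVERT_SHLO_TO_SHARDY is set at compile time. A minimal sketch of enabling it from Python (illustrative; any workload that triggers a compilation would do):

import os
# Must be set before the first compilation so PjRtComputationClient::Compile
# sees it when lowering HLO to StableHLO.
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"

import torch
import torch_xla

t = torch.ones(4, 4).to('xla') + 1
torch_xla.sync()  # compilation happens here, with the SDY import pass enabled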
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 55c6ebf186f..25bbc302da7 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -218,6 +218,26 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
   return xla::protobuf_util::HaveSameSerialization(a, b);
 }
 
+xla::OpSharding ShardingUtil::CreateIotaOpSharding(
+    const py::list& dims, const py::list& reshape_dims,
+    const py::list& transpose_perm, const py::list& types) {
+  TORCH_LAZY_COUNTER("CreateIotaOpSharding", 1);
+  auto dims_vec = dims.cast<std::vector<int64_t>>();
+  auto reshape_dims_vec = reshape_dims.cast<std::vector<int64_t>>();
+  auto transpose_perm_vec = transpose_perm.cast<std::vector<int>>();
+  std::vector<xla::OpSharding::Type> subgroup_types_vec;
+  for (auto type : types) {
+    subgroup_types_vec.push_back(
+        static_cast<xla::OpSharding::Type>(type.cast<int>()));
+  }
+  CHECK_EQ(reshape_dims_vec.size(), transpose_perm_vec.size());
+  return xla::HloSharding::Subgroup(
+             xla::TileAssignment(dims_vec, reshape_dims_vec,
+                                 transpose_perm_vec),
+             subgroup_types_vec)
+      .ToProto();
+}
+
 xla::OpSharding ShardingUtil::CreateOpSharding(
     const py::list& tile_assignment, const py::list& group_assignment,
     const py::list& replication_groups, ShardingType sharding_type) {
@@ -865,4 +885,20 @@ bool ShardingUtil::GetAutoSharding() {
   }
   return use_auto_sharding;
 }
+
+xla::Shape ShardingUtil::GetAdjustedGlobalShape(const at::Tensor& tensor,
+                                                bool minibatch) {
+  xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr);
+  if (minibatch) {
+    int num_local_devices =
+        runtime::GetComputationClientOrDie()->GetLocalDevices().size();
+    int num_global_devices =
+        runtime::GetComputationClientOrDie()->GetAllDevices().size();
+    int batch_dim_shape =
+        tensor.sizes()[0] * num_global_devices / num_local_devices;
+    global_shape.set_dimensions(0, batch_dim_shape);
+  }
+  return global_shape;
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h
index 8b8b98653b2..a925c470748 100644
--- a/torch_xla/csrc/xla_sharding_util.h
+++ b/torch_xla/csrc/xla_sharding_util.h
@@ -51,6 +51,12 @@ class ShardingUtil {
                                           const py::list& group_assignment,
                                           const py::list& replication_groups,
                                           ShardingType sharding_type);
+  // Creates an xla::OpSharding for TILED and PARTIAL types using the
+  // HloShardingV2 system.
+  static xla::OpSharding CreateIotaOpSharding(const py::list& dims,
+                                              const py::list& reshape_dims,
+                                              const py::list& transpose_perm,
+                                              const py::list& types);
 
   // Returns the shape of the resulting shards of `tensor` after applying
   // `sharding`. This assumes the shards will be padded to ensure they all
@@ -150,6 +156,9 @@ class ShardingUtil {
 
   static void SetAutoSharding();
   static bool GetAutoSharding();
+
+  static xla::Shape GetAdjustedGlobalShape(const at::Tensor& tensor,
+                                           bool minibatch);
 };
 
 }  // namespace torch_xla
diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py
index e5f53d04aea..2cb9368aff0 100644
--- a/torch_xla/distributed/spmd/debugging.py
+++ b/torch_xla/distributed/spmd/debugging.py
@@ -157,6 +157,27 @@ def visualize_sharding(sharding: str,
   return table
 
 
+def construct_v1_sharding_str(t: torch.Tensor) -> str:
+  """
+  Returns the corresponding HLO V1 sharding string for the tensor.
+  """
+  sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+  if "<=" not in sharding:
+    # This is already in the V1 format.
+    return sharding
+  sharding_params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t)
+  assert sharding_params is not None
+  tile_assignment_dims, reshape_dims, transpose_perm, replicate_on_last_dim = sharding_params
+  num_devices = np.prod(reshape_dims)
+  device_list = np.arange(num_devices).reshape(reshape_dims).transpose(
+      transpose_perm).reshape(num_devices)
+
+  tile_assignment_str = ",".join(str(dim) for dim in tile_assignment_dims)
+  device_list_str = ",".join(str(i) for i in device_list)
+  replicate_str = " last_tile_dim_replicate" if replicate_on_last_dim else ""
+  return f"{{devices=[{tile_assignment_str}]{device_list_str}{replicate_str}}}"
+
+
 def visualize_tensor_sharding(t, **kwargs):
   """Visualizes an array's sharding."""
 
@@ -164,5 +185,7 @@ def visualize_tensor_sharding(t, **kwargs):
   def maybe_unwrap(t: torch.Tensor) -> torch.Tensor:
     return t.global_tensor if isinstance(t, XLAShardedTensor) else t
 
-  sharding = torch_xla._XLAC._get_xla_sharding_spec(maybe_unwrap(t))
+  t = maybe_unwrap(t)
+  sharding = construct_v1_sharding_str(t)
+
   return visualize_sharding(sharding, **kwargs)
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index c010fd4c352..363658f77dd 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -1,6 +1,7 @@
 import collections
 from collections.abc import Generator, MutableMapping
 import math
+import os
 from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 import torch
@@ -118,9 +119,7 @@ def get_axis_name_idx(self, name: str) -> int:
       return None
     return self.axis_names.index(name)
 
-  @functools.lru_cache(maxsize=None)
-  def _get_op_sharding_args(self, partition_spec: PartitionSpec):
-    partition_spec = _translate_named_partition_spec(self, partition_spec)
+  def _validate_translated_partition_spec(self, partition_spec: tuple):
     flat_specs = np.hstack([d for d in partition_spec])
     specs = [d for d in flat_specs if d is not None]
     assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
@@ -128,6 +127,11 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):
     assert len(specs) == len(np.unique(specs)), \
       f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
 
+  @functools.lru_cache(maxsize=None)
+  def _get_op_sharding_args(self, partition_spec: PartitionSpec):
+    partition_spec = _translate_named_partition_spec(self, partition_spec)
+    self._validate_translated_partition_spec(partition_spec)
+
     tile_assignment = _get_tile_assignment(self, partition_spec)
     if len(tile_assignment.shape) > len(partition_spec):
       # Use partial replication for sharding a tensor over a higher-rank mesh
@@ -142,6 +146,77 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):
     sharding_type = int(sharding_type)
     return tile_assignment, group_assignment, replication_groups, sharding_type
 
+  @functools.lru_cache(maxsize=None)
+  def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
+    """
+    Returns all the sharding parameters needed for TILED or PARTIAL sharding.
+    (All other sharding types are handled separately by the V1 OpSharding
+    function.)
+    """
+    partition_spec = _translate_named_partition_spec(self, partition_spec)
+    self._validate_translated_partition_spec(partition_spec)
+
+    # This algorithm is adapted from
+    # https://github.com/openxla/xla/blob/256b633e0adaee80588a8c3a5e4b2eaa005b5414/xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.cc#L288
+    tile_assignment_dims = [1] * len(partition_spec)
+    axisRefToShardedPos = {}
+    subgroup_types = []
+    shardedPos = 0
+
+    for idx, axes in enumerate(partition_spec):
+      if axes is None:
+        # Tensor dim is being replicated
+        continue
+      elif isinstance(axes, tuple):
+        # Tensor dim is being sharded over multiple axes
+        for axis in axes:
+          tile_assignment_dims[idx] *= self.mesh_shape[axis]
+          axisRefToShardedPos[axis] = shardedPos
+          shardedPos += 1
+      else:
+        # Tensor dim is being sharded over just 1 axis
+        tile_assignment_dims[idx] *= self.mesh_shape[axes]
+        axisRefToShardedPos[axes] = shardedPos
+        shardedPos += 1
+
+    all_axes_ordered = [i for i in range(len(self.mesh_shape))]
+    reshape_dims = [0] * len(all_axes_ordered)
+    transpose_perm = [0] * len(all_axes_ordered)
+
+    totalReplicatedSize = 1
+    replicatedPos = shardedPos
+    for idx, axis in enumerate(all_axes_ordered):
+      reshape_dims[idx] = self.mesh_shape[axis]
+      if axis in axisRefToShardedPos:
+        # Axis is sharded
+        transpose_perm[axisRefToShardedPos[axis]] = idx
+      else:
+        # Axis is replicated
+        transpose_perm[replicatedPos] = idx
+        replicatedPos += 1
+        totalReplicatedSize *= self.mesh_shape[axis]
+
+    if totalReplicatedSize > 1:
+      tile_assignment_dims.append(totalReplicatedSize)
+      subgroup_types.append(ShardingType.REPLICATED)
+
+    return tile_assignment_dims, reshape_dims, transpose_perm, subgroup_types
+
+  @functools.lru_cache(maxsize=None)
+  def get_op_sharding_v2(
+      self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding:
+    """
+    Returns the OpSharding for the given partition spec using V2 annotations.
+    """
+    if len(partition_spec) == 0:
+      return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED)
+    sharding_type = _get_sharding_type(partition_spec, self.size())
+    if sharding_type not in (ShardingType.TILED, ShardingType.PARTIAL):
+      return torch_xla._XLAC.OpSharding([], [], [], sharding_type)
+
+    dims, reshape_dims, transpose_perm, types = self._get_op_sharding_args_v2(
+        partition_spec)
+    return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm, types)
+
   @functools.lru_cache(maxsize=None)
   def get_op_sharding(
       self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding:
@@ -157,6 +232,7 @@ def get_op_sharding(
 
     tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
         partition_spec)
+
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
 
@@ -530,6 +606,11 @@ def _mark_manual_sharding(
   return wrap_as_sharded_tensor(t)
 
 
+def _use_shlo_to_shardy() -> bool:
+  return os.environ.get("CONVERT_SHLO_TO_SHARDY",
+                        "").lower() in ("1", "true", "yes")
+
+
 def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
                            partition_spec: PartitionSpec,
                            *,
@@ -653,7 +734,10 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     t.shard_(NamedSharding(jmesh, P(*partition_spec)))
     return t
 
-  op_sharding = mesh.get_op_sharding(partition_spec)
+  if _use_shlo_to_shardy():
+    op_sharding = mesh.get_op_sharding_v2(partition_spec)
+  else:
+    op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_mark_sharding
   annotate_func(unwrap_sharded_tensor(t), op_sharding)
   # Pass mesh and partition spec information for DTensor compatibility
@@ -832,6 +916,9 @@ def __post_init__(self):
     self._group_assignment, self._replication_groups = _get_group_assignment(
         self._sharding_type, tile_assignment, len(partition_spec),
         replicate_dims)
+    if _use_shlo_to_shardy():
+      self.dims, self.reshape_dims, self.transpose_perm, self.subgroup_types = mesh._get_op_sharding_args_v2(
+          partition_spec)
 
   def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
@@ -840,6 +927,13 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
     if not self.can_apply(t):
       return None
+
+    if _use_shlo_to_shardy():
+      return torch_xla._XLAC.XlaShardingSpec(t, self.dims, self.reshape_dims,
+                                             self.transpose_perm,
+                                             self.subgroup_types,
+                                             self.minibatch)
+
     return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
                                            self._group_assignment,
                                            self._replication_groups,