72 changes: 72 additions & 0 deletions test/spmd/test_spmd_debugging.py
@@ -17,6 +17,7 @@
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
from torch_xla.distributed.spmd import Mesh
from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str

import test_xla_sharding_base

@@ -828,6 +829,77 @@ def test_multi_host_replicated_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output


class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):
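"""Tests that construct_v1_sharding_str rebuilds the V1 HLO sharding string
for tensors sharded while CONVERT_SHLO_TO_SHARDY (the Shardy/V2 path) is set."""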

@classmethod
def setUpClass(cls):
super().setUpClass()
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"

def run_test(self):
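# Shards a fresh random tensor with the configured mesh shape and partition
# spec, then compares the reconstructed V1 sharding string to expected_str.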
mesh = self._get_mesh(self.device_mesh_shape)
t = torch.randn(self.tensor_shape).to(torch_xla.device())
xs.mark_sharding(t, mesh, self.partition_spec)
actual_str = construct_v1_sharding_str(t)
self.assertEqual(self.expected_str, actual_str)

def test_tiled_sharding(self):
self.device_mesh_shape = (1, self.n_devices)
self.tensor_shape = (1, 128)
self.partition_spec = (0, 1)
self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in range(self.n_devices)]))
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 2,
f"Requires at least 2 devices.")
def test_tupled_tiled_sharding(self):
self.device_mesh_shape = (2, self.n_devices // 2)
self.tensor_shape = (16,)
self.partition_spec = ((0, 1),)
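# The tupled spec shards the single tensor dimension across both mesh axes,
# yielding a 1-D tile assignment over all devices.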
self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
self.run_test()

def test_replicated_sharding(self):
self.device_mesh_shape = (1, self.n_devices)
self.tensor_shape = (4, 4)
self.partition_spec = (None, None)
self.expected_str = '{replicated}'
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 4,
f"Requires at least 4 devices.")
def test_partial_replication_sharding(self):
self.device_mesh_shape = (2, self.n_devices // 2)
self.tensor_shape = (4, 4)
self.partition_spec = (0, None)
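# The second mesh axis is unused by the spec, so each shard is replicated
# across it; the V1 string encodes this as an extra trailing tile dimension
# marked 'last_tile_dim_replicate'.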
self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 4,
f"Requires at least 4 devices.")
def test_tupled_partial_replication_sharding(self):
self.device_mesh_shape = (1, 2, self.n_devices // 2)
self.tensor_shape = (16, 16)
self.partition_spec = ((0, 1), None)
self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 2,
"Requires at least 2 devices.")
def test_tupled_partial_replication_sharding_with_transpose(self):
self.device_mesh_shape = (1, 2, self.n_devices // 2)
self.tensor_shape = (16, 16)
self.partition_spec = (None, (2, 1))
device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
(2, 1, 0)).flatten()
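# Dim 1 is sharded over mesh axes (2, 1) in that order, so the expected device
# list is the logical mesh traversed with those axes leading (hence the transpose).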
self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in device_order))
self.run_test()


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
99 changes: 63 additions & 36 deletions test/spmd/test_xla_sharding.py
@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
@classmethod
def setUpClass(cls):
super().setUpClass()
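# When CONVERT_SHLO_TO_SHARDY is set, sharding specs are reported in the V2
# iota form (e.g. '{devices=[1,4]<=[4]}') instead of listing every device id.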
cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")

def test_xla_sharded_tensor(self):
partition_spec = (0, 1)
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
if self.n_devices > 1:
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in reversed(range(self.n_devices))]))
if self.convert_to_shardy:
annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

def test_mark_sharding_2d(self):
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
if self.n_devices > 1:
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in range(self.n_devices)]))
if self.convert_to_shardy:
annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

actual = (xt1 + xt2).cpu()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
annotation = '{devices=[1,1,%d,%d]%s}' % (
z_dim, self.n_devices // z_dim, ','.join(
[str(i) for i in range(self.n_devices)]))
if self.convert_to_shardy:
annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
z_dim, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

actual = (xt + xt).cpu()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
mesh = self._get_mesh((2, self.n_devices // 2))
t = torch.randn(16).to('xla')
xs.mark_sharding(t, mesh, ((0, 1),))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
(self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
"Multiple devices required for tupled partition spec")
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
# Shard the first dimension on `r` and `b`, replicate the second dimension
t = torch.randn(16, 16).to('xla')
xs.mark_sharding(t, mesh, (('r', 'b'), None))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t),
"{devices=[2,1,%d]%s last_tile_dim_replicate}" %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
self.n_devices // 2, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

# Replicate the first dimension, shard the second on `b` and `m`
u = torch.randn(16, 16).to('xla')
xs.mark_sharding(u, mesh, (None, ('b', 'm')))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
(self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)

# Replicate the first dimension, shard the second on `r` and `m`
v = torch.randn(16, 16).to('xla')
xs.mark_sharding(v, mesh, (None, ('r', 'm')))
device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v),
"{devices=[1,%d,2]%s last_tile_dim_replicate}" %
(self.n_devices // 2, ','.join(str(x) for x in device_order)))
annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in device_order))
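# A transposed device order is captured in the iota form by the reshape dims
# plus a transpose permutation, e.g. '<=[2,2]T(1,0)'.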
if self.convert_to_shardy:
annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
self.n_devices // 2, self.n_devices // 2)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

# Replicate the first dimension, shard the second on `m` and `b`
v = torch.randn(16, 16).to('xla')
xs.mark_sharding(v, mesh, (None, ('m', 'b')))
device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
(self.n_devices, ','.join(str(x) for x in device_order)))
annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in device_order))
if self.convert_to_shardy:
annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
self.n_devices // 2)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
'Multiple devices required for tupled partition spec')
@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
('a', 'b', 'c', 'd'))
t = torch.randn(2, 2).to('xla')
xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
'At least 2 devices needed for 2D mesh')
def test_3d_tensor_2d_mesh(self):
mesh = self._get_mesh((2, self.n_devices // 2))
t = torch.randn(16, 16, 16).to('xla')
xs.mark_sharding(t, mesh, (None, 0, 1))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

def test_partial_replication_addmm(self):
device = torch_xla.device()
@@ -983,18 +1008,20 @@ def test_op_sharding_cache(self):

t = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(t, mesh, (0, 1))
self.assertIn("CreateOpSharding", met.counter_names())
self.assertEqual(met.counter_value("CreateOpSharding"), 1)
counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
self.assertIn(counter_name, met.counter_names())
self.assertEqual(met.counter_value(counter_name), 1)

# Sharding with the same partition spec should not result in another call
u = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(u, mesh, (0, 1))
self.assertEqual(met.counter_value("CreateOpSharding"), 1)
self.assertEqual(met.counter_value(counter_name), 1)

# Changing the partition spec will result in another CreateOpSharding
# Changing the partition spec will result in another
# CreateOpSharding or CreateIotaOpSharding call.
v = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)
self.assertEqual(met.counter_value(counter_name), 2)

def test_from_cpu_shards_replicated(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
@@ -1397,10 +1424,10 @@ def test_data_loader_with_sharding(self):
input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
data, _ = iter(train_device_loader).__next__()
self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(data),
f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
)
annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
if self.convert_to_shardy:
annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

@unittest.skipUnless(
xr.global_runtime_device_count() > 1,
@@ -1420,10 +1447,10 @@ def test_data_loader_with_non_batch_size(self):
input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
data, _ = iter(train_device_loader).__next__()
self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(data),
f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
)
annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
if self.convert_to_shardy:
annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

@unittest.skipUnless(
xr.global_runtime_device_count() > 1,