diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh
index f7671cc3d82..3472a073ea3 100755
--- a/test/neuron/run_tests.sh
+++ b/test/neuron/run_tests.sh
@@ -256,6 +256,7 @@ function run_xla_op_tests3 {
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"
diff --git a/test/run_tests.sh b/test/run_tests.sh
index b2cc8f751d2..dabcdb83961 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -255,6 +255,7 @@ function run_xla_op_tests3 {
   run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
diff --git a/test/spmd/test_xla_dtensor_from_local.py b/test/spmd/test_xla_dtensor_from_local.py
new file mode 100644
index 00000000000..40647e20590
--- /dev/null
+++ b/test/spmd/test_xla_dtensor_from_local.py
@@ -0,0 +1,149 @@
+import sys
+import unittest
+import torch
+import numpy as np
+
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed._tensor import DTensor
+from torch.distributed.tensor.placement_types import Replicate, Shard
+import torch_xla
+import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
+from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
+import test_xla_sharding_base
+
+
+class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
+  """
+  Test suite for the automatic conversion of regular tensors to XLAShardedTensor
+  in DTensor.from_local() when using an XLA device mesh.
+  """
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_basic_conversion(self):
+    """Test basic conversion of regular tensor to XLAShardedTensor."""
+    world_size = xr.global_runtime_device_count()
+
+    # Create a regular tensor (not on XLA device)
+    tensor = torch.randn(100_000, 88)
+    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison
+
+    # Create a DeviceMesh
+    device_mesh = DeviceMesh("xla", list(range(world_size)))
+
+    # Use DTensor.from_local with the regular tensor
+    dt = DTensor.from_local(tensor, device_mesh=device_mesh)
+
+    # Verify the tensor was converted correctly
+    self.assertEqual(dt.shape, tensor.shape)
+
+    # Check the value of the tensor
+    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)
+
+    # Verify operations work
+    result = dt + 1.0
+    self.assertEqual(result.shape, tensor.shape)
+
+    print("Basic conversion successful")
+
+
+  def test_conversion_with_placements(self):
+    """Test conversion with explicit placements."""
+    world_size = xr.global_runtime_device_count()
+
+    # Create a regular tensor (not on XLA device)
+    tensor = torch.randn(100_000, 88)
+    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison
+
+    # Create a DeviceMesh
+    device_mesh = DeviceMesh("xla", list(range(world_size)))
+
+    # Use DTensor.from_local with explicit placements
+    dt = DTensor.from_local(
+        tensor,
+        device_mesh=device_mesh,
+        placements=[Replicate()]
+    )
+
+    # Verify the tensor was converted correctly
+    self.assertEqual(dt.shape, tensor.shape)
+
+    # Check the value of the tensor
+    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)
+
+    # Verify operations work
+    result = dt + 1.0
+    self.assertEqual(result.shape, tensor.shape)
+
+    print("Conversion with placements successful")
+
+  def test_conversion_with_sharding(self):
+    """Test conversion with sharding placement."""
+    world_size = xr.global_runtime_device_count()
+    if world_size < 2:
+      self.skipTest("Need at least 2 devices for sharding test")
+
+    # Create a tensor divisible by world_size
+    tensor = torch.randn(100_000, 88)
+    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison
+
+    # Create a DeviceMesh
+    device_mesh = DeviceMesh("xla", list(range(world_size)))
+
+    # Use DTensor.from_local with sharding placement
+    dt = DTensor.from_local(
+        tensor,
+        device_mesh=device_mesh,
+        placements=[Shard(0)]
+    )
+
+    # Verify the tensor was converted correctly
+    self.assertEqual(dt.shape, tensor.shape)
+
+    # Check the value of the tensor
+    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)
+
+    # Verify operations work
+    result = dt + 1.0
+    self.assertEqual(result.shape, tensor.shape)
+
+    print("Conversion with sharding successful")
+
+  def test_conversion_with_different_dtypes(self):
+    """Test conversion with different dtypes."""
+    world_size = xr.global_runtime_device_count()
+    device_mesh = DeviceMesh("xla", list(range(world_size)))
+
+    # Test with different dtypes
+    for dtype in [torch.float16, torch.float32, torch.int32, torch.int64]:
+      # Create a tensor with specific dtype
+      tensor = torch.ones(100_000, 88, dtype=dtype)
+      tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison
+
+      # Use DTensor.from_local with the tensor
+      dt = DTensor.from_local(tensor, device_mesh=device_mesh)
+
+      # Verify dtype is preserved
+      self.assertEqual(dt.dtype, dtype)
+
+      # Check the value of the tensor
+      torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)
+
+      # Verify operations work
+      if dtype.is_floating_point:
+        result = dt + 1.0
+      else:
+        result = dt + 1
+
+      self.assertEqual(result.shape, tensor.shape)
+      self.assertEqual(result.dtype, dtype)
+
+      print(f"Conversion with {dtype} successful")
+
+
+if __name__ == "__main__":
+  result = unittest.main(exit=False)
+  sys.exit(0 if result.result.wasSuccessful() else 1)
\ No newline at end of file
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 24f18d3bdcd..ec585716cb6 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -62,6 +62,7 @@ run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
 run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
 run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
+run_test "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"
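
For reference, below is a minimal standalone sketch of the pattern the new test exercises, outside the unittest harness: passing a plain CPU tensor to DTensor.from_local() with an "xla" DeviceMesh and checking the resulting DTensor's shape and dtype. The explicit xr.use_spmd() call and the smaller tensor shape are assumptions made for this sketch, not part of the diff above.

# Standalone sketch; assumes torch_xla is installed and XLA devices are available.
import torch
from torch.distributed.tensor import DeviceMesh
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.placement_types import Shard
import torch_xla.runtime as xr

xr.use_spmd()  # assumed prerequisite: enable SPMD execution before building the "xla" mesh

world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(world_size)))

local = torch.randn(1024, 88)  # regular CPU tensor, not placed on an XLA device
dt = DTensor.from_local(local, device_mesh=mesh, placements=[Shard(0)])

# Per the tests above, the global shape and dtype match the local input.
print(type(dt).__name__, dt.shape, dt.dtype)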