@@ -1,13 +1,26 @@
+from copy import copy
 import os
 import sys
 import unittest

 import numpy as np
 import torch
+import torch.nn as nn
 import torch_xla
 import torch_xla.runtime as xr
-from torch_xla.experimental.assume_pure import assume_pure
+from torch_xla.experimental.assume_pure import PureModule, assume_pure
 from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, Mesh
+from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
+
+
+def get_2d_mesh(name1: str, name2: str):
+  num_devices = xr.global_runtime_device_count()
+  dim1_size = 2
+  assert num_devices % 2 == 0
+  dim2_size = num_devices // dim1_size
+  devices = np.arange(xr.global_runtime_device_count())
+  mesh_shape = (dim1_size, dim2_size)
+  return Mesh(devices, mesh_shape=mesh_shape, axis_names=(name1, name2))


 class AssumePureSpmdTest(unittest.TestCase):
@@ -56,6 +69,44 @@ def test_assume_pure_works_with_mark_sharding_with_gradients(self):
     self.assertIn(f'devices=[{N}',
                   torch_xla._XLAC._get_xla_sharding_spec(x.grad))

+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipIf(
+      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
+      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
+  )
+  def test_assume_pure_works_with_mark_sharding_2d(self):
+    mesh = get_2d_mesh("model", "batch")
+    set_global_mesh(mesh)
+    x = torch.randn((8, 4, 5, 128), device='xla')
+    result = assume_pure(mark_sharding)(x, mesh,
+                                        (("model", "batch"), None, None, None))
+    torch_xla.sync(wait=True)
+    N = xr.global_runtime_device_count()
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(result))
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipIf(
+      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
+      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
+  )
+  def test_assume_pure_works_with_mark_sharding_with_gradients_2d(self):
+    mesh = get_2d_mesh("model", "batch")
+    set_global_mesh(mesh)
+    x = torch.randn((8, 4, 5, 128)).to('xla').requires_grad_(True)
+    result = assume_pure(mark_sharding_with_gradients)(
+        x, mesh, (("model", "batch"), None, None, None))
+    result.sum().backward()
+    torch_xla.sync(wait=True)
+    N = xr.global_runtime_device_count()
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(result))
+    assert x.grad is not None
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(x.grad))
+
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required")
   @unittest.skipIf(
@@ -94,6 +145,33 @@ def test_convert_to_jax_mesh_shuffled(self):
         np.array([dev['coords'] for dev in torch_xla_devices.flatten()]),
     )

+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipUnless(os.environ.get('PJRT_DEVICE') == 'TPU', "TPU only test")
+  def test_pure_module(self):
+    """Test tracing `nn.Linear` and `EinsumLinear` with `assume_pure`."""
+    for transform in [apply_xla_patch_to_nn_linear, lambda x: x]:
+      with torch_xla.device():
+        # Arrange
+        original = nn.Linear(4, 8)
+        replaced = PureModule(transform(copy(original)))
+        inputs = torch.ones((4,))
+        torch_xla.sync()
+
+        # Act
+        original_output = original(inputs)
+        original_output.sum().backward()
+        replaced_output = replaced(inputs)
+        replaced_output.sum().backward()
+        torch_xla.sync()
+
+        # Assert
+        torch.testing.assert_close(original_output, replaced_output)
+        torch.testing.assert_close(original.weight.grad,
+                                   replaced._module.weight.grad)
+        torch.testing.assert_close(original.bias.grad,
+                                   replaced._module.bias.grad)
+

 if __name__ == '__main__':
   test = unittest.main()
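
Editor's note: the sketch below (not part of the commit) shows how the APIs exercised by the new tests fit together outside the test harness: a 2D ("model", "batch") mesh built the same way as `get_2d_mesh`, `mark_sharding` traced through `assume_pure`, and an `nn.Linear` wrapped in `PureModule`. The layer sizes, tensor shape, and the `xr.use_spmd()` call are illustrative assumptions rather than anything taken from this diff.

# Minimal usage sketch (editor's illustration, not part of the commit).
# Mesh split, layer sizes, tensor shape, and xr.use_spmd() are assumptions.
import numpy as np
import torch
import torch.nn as nn
import torch_xla
import torch_xla.runtime as xr
from torch_xla.experimental.assume_pure import PureModule, assume_pure
from torch_xla.distributed.spmd import Mesh, mark_sharding, set_global_mesh

xr.use_spmd()  # enable SPMD execution mode (assumed prerequisite)

# Build a 2D ("model", "batch") mesh over all devices, mirroring get_2d_mesh.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(
    np.arange(num_devices),
    mesh_shape=(2, num_devices // 2),
    axis_names=("model", "batch"))
set_global_mesh(mesh)

with torch_xla.device():
  # Wrap a module so its forward pass is traced under assume_pure.
  layer = PureModule(nn.Linear(128, 64))
  x = torch.randn((8, 4, 5, 128))
  # Shard the leading dimension across both mesh axes, traced as a pure op.
  assume_pure(mark_sharding)(x, mesh, (("model", "batch"), None, None, None))
  y = layer(x)
  torch_xla.sync()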