diff --git a/API_GUIDE.md b/API_GUIDE.md index f2e9fc1cc2dd..f13ff86ab10b 100644 --- a/API_GUIDE.md +++ b/API_GUIDE.md @@ -22,7 +22,7 @@ print(t) This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and -`torch_xla.device()` returns the current XLA device. This may be a CPU or TPU +`torch.device('xla')` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors @@ -112,7 +112,7 @@ train_loader = xu.SampleGenerator( torch.zeros(batch_size, dtype=torch.int64)), sample_count=60000 // batch_size // xr.world_size()) -device = torch_xla.device() # Get the XLA device (TPU). +device = torch.device('xla') # Get the XLA device (TPU). model = MNIST().train().to(device) # Create a model and move it to the device. loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) @@ -169,7 +169,7 @@ def _mp_fn(index): index: Index of the process. """ - device = torch_xla.device() # Get the device assigned to this process. + device = torch.device('xla') # Get the device assigned to this process. # Wrap the loader for multi-device. mp_device_loader = pl.MpDeviceLoader(train_loader, device) @@ -290,7 +290,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -device = torch_xla.device() +device = torch.device('xla') t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py index b784af68e47b..9c5867cbfc9a 100644 --- a/benchmarks/experiment_runner.py +++ b/benchmarks/experiment_runner.py @@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment, def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment, benchmark_model: BenchmarkModel, input_tensor): - device = torch_xla.device() if benchmark_experiment.xla else 'cuda' + device = torch.device('xla') if benchmark_experiment.xla else 'cuda' sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize timing, output = bench.do_bench( lambda: benchmark_model.model_iter_fn( diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb index 8d4fbd95bff7..c829c4b9a36f 100644 --- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb +++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb @@ -193,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-01-10T19:30:28.607393Z", @@ -210,7 +210,7 @@ "lock = mp.Manager().Lock()\n", "\n", "def print_device(i, lock):\n", - " device = torch_xla.device()\n", + " device = torch.device('xla')\n", " with lock:\n", " print('process', i, device)" ] @@ -454,7 +454,7 @@ "import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n", "\n", "def toy_model(index, lock):\n", - " device = torch_xla.device()\n", + " device = torch.device('xla')\n", " dist.init_process_group('xla', init_method='xla://')\n", "\n", " # Initialize a basic toy model\n", diff --git a/docs/source/learn/_pjrt.md b/docs/source/learn/_pjrt.md index edaa56ecee72..38cd322e7940 100644 --- a/docs/source/learn/_pjrt.md +++ b/docs/source/learn/_pjrt.md @@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend def _mp_fn(index): - device = torch_xla.device() + device = 
torch_xla.device() +   device = torch.device('xla') -   dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size()) +   dist.init_process_group('xla', init_method='xla://') diff --git a/docs/source/learn/eager.md b/docs/source/learn/eager.md index 0d82ae3c581c..cbf54d3a6c32 100644 --- a/docs/source/learn/eager.md +++ b/docs/source/learn/eager.md @@ -13,7 +13,7 @@ import torch import torch_xla import torchvision -device = torch_xla.device() +device = torch.device('xla') model = torchvision.models.resnet18().to(device) input = torch.randn(64, 3, 224, 224).to(device) @@ -71,7 +71,7 @@ import torchvision # Run ops eagerly by default torch_xla.experimental.eager_mode(True) -device = torch_xla.device() +device = torch.device('xla') model = torchvision.models.resnet18().to(device) # Mark the function to be compiled diff --git a/docs/source/learn/pytorch-on-xla-devices.md b/docs/source/learn/pytorch-on-xla-devices.md index c0b48bec1813..0be3ce038e5f 100644 --- a/docs/source/learn/pytorch-on-xla-devices.md +++ b/docs/source/learn/pytorch-on-xla-devices.md @@ -21,7 +21,7 @@ print(t) This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes -PyTorch/XLA, and `torch_xla.device()` returns the current XLA device. This +PyTorch/XLA, and `torch.device('xla')` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors @@ -81,7 +81,7 @@ The following snippet shows a network training on a single XLA device: ``` python import torch_xla.core.xla_model as xm -device = torch_xla.device() +device = torch.device('xla') model = MNIST().train().to(device) loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) @@ -120,7 +120,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') mp_device_loader = pl.MpDeviceLoader(train_loader, device) model = MNIST().train().to(device) @@ -148,7 +148,7 @@ previous single device snippet. Let's go over then one by one. will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device. - - Note that if you print the `torch_xla.device()` on each process you + - Note that if you print the `torch.device('xla')` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only execution is with PJRT runtime on TPU v2 @@ -283,7 +283,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -device = torch_xla.device() +device = torch.device('xla') t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) diff --git a/docs/source/learn/xla-overview.md b/docs/source/learn/xla-overview.md index f6b0761fd69a..e74247c2fb88 100644 --- a/docs/source/learn/xla-overview.md +++ b/docs/source/learn/xla-overview.md @@ -184,7 +184,7 @@ repo. contains examples for training and serving many LLM and diffusion models. General guidelines to modify your code: -- Replace `cuda` with `torch_xla.device()` +- Replace `cuda` with `torch.device('xla')` - Remove progress bar, printing that would access the XLA tensor values - Reduce logging and callbacks that would access the XLA tensor values @@ -227,7 +227,7 @@ tutorial, but you can pass the `device` value to the function as well. 
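As a concrete illustration of the "replace `cuda`" guideline above, here is a minimal sketch (not part of the changed files; the `nn.Linear` module and tensor shapes are placeholders) of what a script looks like after the substitution:

``` python
import torch
import torch.nn as nn
import torch_xla  # importing torch_xla registers the 'xla' device type

# The CUDA-style `device = torch.device('cuda')` becomes:
device = torch.device('xla')

model = nn.Linear(128, 10).to(device)       # move the model to the XLA device
data = torch.randn(8, 128, device=device)   # create inputs directly on the device
output = model(data)

torch_xla.sync()    # materialize the lazily traced computation
print(output.cpu())
```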
``` python import torch_xla.core.xla_model as xm - self.device = torch_xla.device() + self.device = torch.device('xla') ``` Another place in the code that has cuda specific code is DDIM scheduler. @@ -244,7 +244,7 @@ if attr.device != torch.device("cuda"): with ``` python -device = torch_xla.device() +device = torch.device('xla') attr = attr.to(torch.device(device)) ``` @@ -339,7 +339,7 @@ with the following lines: ``` python import torch_xla.core.xla_model as xm -device = torch_xla.device() +device = torch.device('xla') pipe.to(device) ``` diff --git a/docs/source/perf/amp.md b/docs/source/perf/amp.md index 4ad48753d45c..0d0db54f1682 100644 --- a/docs/source/perf/amp.md +++ b/docs/source/perf/amp.md @@ -27,7 +27,7 @@ for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(torch_xla.device()): + with autocast(torch.device('xla')): output = model(input) loss = loss_fn(output, target) @@ -36,7 +36,7 @@ for input, target in data: xm.optimizer_step.(optimizer) ``` -`autocast(torch_xla.device())` aliases `torch.autocast('xla')` when the XLA +`autocast(torch.device('xla'))` aliases `torch.autocast('xla')` when the XLA Device is a TPU. Alternatively, if a script is only used with TPUs, then `torch.autocast('xla', dtype=torch.bfloat16)` can be directly used. @@ -115,7 +115,7 @@ for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(torch_xla.device()): + with autocast(torch.device('xla')): output = model(input) loss = loss_fn(output, target) @@ -127,12 +127,12 @@ for input, target in data: scaler.update() ``` -`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the +`autocast(torch.device('xla'))` aliases `torch.cuda.amp.autocast()` when the XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is only used with CUDA devices, then `torch.cuda.amp.autocast` can be directly used, but requires `torch` is compiled with `cuda` support for datatype of `torch.bfloat16`. We recommend using -`autocast(torch_xla.device())` on XLA:GPU as it does not require +`autocast(torch.device('xla'))` on XLA:GPU as it does not require `torch.cuda` support for any datatypes, including `torch.bfloat16`. 
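For reference, a complete single-device training step with the TPU-oriented spelling mentioned above might look like the following sketch; the model, loss function, and shapes are illustrative only and are not taken from the documentation being changed:

``` python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch.device('xla')
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

data = torch.randn(32, 16, device=device)
target = torch.randn(32, 4, device=device)

optimizer.zero_grad()
# Autocast the forward pass to bfloat16 on the XLA device.
with torch.autocast('xla', dtype=torch.bfloat16):
    output = model(data)
    loss = loss_fn(output, target)
loss.backward()
# Steps the optimizer (and reduces gradients in multi-replica runs).
xm.optimizer_step(optimizer)
```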
### AMP for XLA:GPU Best Practices diff --git a/docs/source/perf/ddp.md b/docs/source/perf/ddp.md index efc4071d648d..51067d37044a 100644 --- a/docs/source/perf/ddp.md +++ b/docs/source/perf/ddp.md @@ -105,7 +105,7 @@ def demo_basic(rank): setup(rank, world_size) # create model and move it to XLA device - device = torch_xla.device() + device = torch.device('xla') model = ToyModel().to(device) ddp_model = DDP(model, gradient_as_bucket_view=True) diff --git a/docs/source/perf/dynamo.md b/docs/source/perf/dynamo.md index 090decb77371..2ab3982fe820 100644 --- a/docs/source/perf/dynamo.md +++ b/docs/source/perf/dynamo.md @@ -41,7 +41,7 @@ import torchvision import torch_xla.core.xla_model as xm def eval_model(loader): - device = torch_xla.device() + device = torch.device('xla') xla_resnet18 = torchvision.models.resnet18().to(device) xla_resnet18.eval() dynamo_resnet18 = torch.compile( @@ -129,7 +129,7 @@ def train_model(model, data, target, optimizer): return pred def train_model_main(loader): - device = torch_xla.device() + device = torch.device('xla') xla_resnet18 = torchvision.models.resnet18().to(device) xla_resnet18.train() dynamo_train_model = torch.compile( diff --git a/docs/source/perf/fori_loop.md b/docs/source/perf/fori_loop.md index bfdd2bf318ab..b6ebf57e09a8 100644 --- a/docs/source/perf/fori_loop.md +++ b/docs/source/perf/fori_loop.md @@ -30,7 +30,7 @@ result = while_loop(cond_fn, body_fn, init) >>> from torch._higher_order_ops.while_loop import while_loop >>> import torch_xla.core.xla_model as xm >>> ->>> device = torch_xla.device() +>>> device = torch.device('xla') >>> >>> def cond_fn(iteri, x): ... return iteri > 0 @@ -60,7 +60,7 @@ with similar logic: cumulative plus 1 for ten times: >>> import torch_xla >>> import torch_xla.core.xla_model as xm >>> ->>> device = torch_xla.device() +>>> device = torch.device('xla') >>> >>> init_val = torch.tensor(1, device=device) >>> iteri = torch.tensor(50, device=device) diff --git a/docs/source/perf/quantized_ops.md b/docs/source/perf/quantized_ops.md index 6d44b05e433b..8aa9ed063dc0 100644 --- a/docs/source/perf/quantized_ops.md +++ b/docs/source/perf/quantized_ops.md @@ -48,7 +48,7 @@ scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16) # Call with torch CPU tensor (For debugging purpose) matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler) -device = torch_xla.device() +device = torch.device('xla') x_xla = x.to(device) w_int_xla = w_int.to(device) scaler_xla = scaler.to(device) diff --git a/examples/train_decoder_only_base.py b/examples/train_decoder_only_base.py index b3b3a33590e9..ae6efea7d079 100644 --- a/examples/train_decoder_only_base.py +++ b/examples/train_decoder_only_base.py @@ -35,7 +35,7 @@ def __init__(self, torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64)), sample_count=self.train_dataset_len // self.batch_size) - self.device = torch_xla.device() + self.device = torch.device('xla') self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device) self.model = decoder_cls(self.config).to(self.device) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001) diff --git a/examples/train_resnet_amp.py b/examples/train_resnet_amp.py index f5ca308bed75..7b0b68a10da2 100644 --- a/examples/train_resnet_amp.py +++ b/examples/train_resnet_amp.py @@ -19,7 +19,7 @@ def train_loop_fn(self, loader, epoch): for step, (data, target) in enumerate(loader): self.optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(torch_xla.device()): + with 
autocast(torch.device('xla')): output = self.model(data) loss = self.loss_fn(output, target) # TPU amp uses bf16 hence gradient scaling is not necessary. If runnign with XLA:GPU diff --git a/examples/train_resnet_base.py b/examples/train_resnet_base.py index c4a8890e9be7..59ff180934f1 100644 --- a/examples/train_resnet_base.py +++ b/examples/train_resnet_base.py @@ -28,7 +28,7 @@ def __init__(self): sample_count=self.train_dataset_len // self.batch_size // xr.world_size()) - self.device = torch_xla.device() + self.device = torch.device('xla') self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device) self.model = torchvision.models.resnet50().to(self.device) self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4) diff --git a/plugins/cpu/README.md b/plugins/cpu/README.md index 76c9d0b7c88e..d3771094768d 100644 --- a/plugins/cpu/README.md +++ b/plugins/cpu/README.md @@ -38,5 +38,5 @@ plugins.use_dynamic_plugins() plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin()) xr.set_device_type('CPU') -print(torch_xla.device()) +print(torch.device('xla')) ``` diff --git a/plugins/cuda/README.md b/plugins/cuda/README.md index 45a002e06f6c..d3760610046c 100644 --- a/plugins/cuda/README.md +++ b/plugins/cuda/README.md @@ -35,5 +35,5 @@ plugins.use_dynamic_plugins() plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin()) xr.set_device_type('CUDA') -print(torch_xla.device()) +print(torch.device('xla')) ``` diff --git a/test/bench.py b/test/bench.py index e5eff86a34d5..bb68dcda052e 100644 --- a/test/bench.py +++ b/test/bench.py @@ -29,7 +29,7 @@ class BaseBench(object): def __init__(self, args): self.args = args - self.device = torch_xla.device() + self.device = torch.device('xla') self.test_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0) torch.manual_seed(42) diff --git a/test/debug_tool/test_mp_pt_xla_debug.py b/test/debug_tool/test_mp_pt_xla_debug.py index 785554657b14..baf58cea6dfd 100644 --- a/test/debug_tool/test_mp_pt_xla_debug.py +++ b/test/debug_tool/test_mp_pt_xla_debug.py @@ -16,7 +16,7 @@ def _mp_fn(index): assert False, "This test should be run with PT_XLA_DEBUG_FILE" if index == 0: open(debug_file_name, 'w').close() - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(10, 10, device=device) t2 = t1 * 100 torch_xla.sync() diff --git a/test/debug_tool/test_pt_xla_debug.py b/test/debug_tool/test_pt_xla_debug.py index 4ebcb2cd1bb9..54abfb98a3b5 100644 --- a/test/debug_tool/test_pt_xla_debug.py +++ b/test/debug_tool/test_pt_xla_debug.py @@ -31,7 +31,7 @@ def setUpClass(cls): def test_eager_sync(self): with torch_xla.experimental.eager_mode_context(True): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 9, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: @@ -41,7 +41,7 @@ def test_eager_sync(self): open(self.debug_file_name, 'w').close() def test_user_sync(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(2, 2, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: @@ -79,7 +79,7 @@ def test_user_sync(self): open(self.debug_file_name, 'w').close() def test_step_trace(self): - device = torch_xla.device() + device = torch.device('xla') with xp.StepTrace('train_pt_xla_debug'): t1 = torch.randn(3, 3, device=device) with open(self.debug_file_name, 'rb') as f: @@ -111,7 +111,7 @@ def test_step_trace(self): open(self.debug_file_name, 'w').close() def test_dynamo(self): - device = torch_xla.device() + device = 
torch.device('xla') t1 = torch.randn(4, 4, device=device) def toy_program(t1): @@ -161,7 +161,7 @@ def toy_program(t1): open(self.debug_file_name, 'w').close() def test_torch_xla_compile(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(12, 4, device=device) def toy_program(t1): @@ -209,7 +209,7 @@ def toy_program(t1): open(self.debug_file_name, 'w').close() def test_torch_xla_compile_custom_name(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(18, 4, device=device) def toy_program2(t1): @@ -239,7 +239,7 @@ def toy_program2(t1): open(self.debug_file_name, 'w').close() def test_parallel_loader(self): - device = torch_xla.device() + device = torch.device('xla') train_dataset_len = 100 batch_size = 10 @@ -287,7 +287,7 @@ def test_parallel_loader(self): open(self.debug_file_name, 'w').close() def test_print(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5, device=device) print(t1) with open(self.debug_file_name, 'rb') as f: @@ -315,7 +315,7 @@ def test_print(self): open(self.debug_file_name, 'w').close() def test_frame(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(6, 6, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: diff --git a/test/distributed_util.py b/test/distributed_util.py index 85069aaabc82..32f04712575e 100644 --- a/test/distributed_util.py +++ b/test/distributed_util.py @@ -101,7 +101,7 @@ def ddp_correctness(init_method: str = 'env://', dist.init_process_group("xla", init_method=init_method) rank, world_size = dist.get_rank(), dist.get_world_size() - device = torch_xla.device() + device = torch.device('xla') # Module initialization is not thread safe. Force threads to initialize one # at a time with the same seed diff --git a/test/ds/test_dynamic_shape_models.py b/test/ds/test_dynamic_shape_models.py index 114c41e5c829..7f5e50a838d8 100644 --- a/test/ds/test_dynamic_shape_models.py +++ b/test/ds/test_dynamic_shape_models.py @@ -17,7 +17,7 @@ # It enables us to run python implementations of CompositeAutogradImplicit ops. # CompositeAutogradImplicit means we don't have an explicit backward formula for an op instead an op is composed of a bunch of ops that do have backward formulas and combines this formulas is equivalent to differentiating the op explicitly. 
pd = torch._C._EnablePythonDispatcher() -xla_dev = torch_xla.device() +xla_dev = torch.device('xla') class Feedforward(torch.nn.Module): diff --git a/test/ds/test_dynamic_shapes.py b/test/ds/test_dynamic_shapes.py index 46f329de4537..2d9e4d5bb7a2 100644 --- a/test/ds/test_dynamic_shapes.py +++ b/test/ds/test_dynamic_shapes.py @@ -10,7 +10,7 @@ import test_utils pd = torch._C._EnablePythonDispatcher() -dev = torch_xla.device() +dev = torch.device('xla') class TestDynamicShapes(test_utils.XlaTestCase): @@ -192,7 +192,7 @@ def test_nonzero_cast(self): torch_xla.sync() def test_expand_symint_correctness(self): - dev = torch_xla.device() + dev = torch.device('xla') size1 = 5 size2 = 2 t1 = torch.ones([size1, size2]) diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 5aa57abd3575..feb5898d9d80 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -116,7 +116,7 @@ def unwrap(cont): def make_reuse_graph_test(module_class, niter=100): def test_wrapper(self): - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') xla_module = module_class().to(device=xla_dev) inputs = tuple(x.to(device=xla_dev) for x in xla_module.get_random_inputs()) metrics.clear_counters() @@ -187,7 +187,7 @@ def make_training_test(model_cls): def test_wrapper(self): import torch_xla.core.xla_model as xm - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') model = model_cls() inputs = model.get_random_inputs() @@ -240,7 +240,7 @@ class Emb(torch.nn.Embedding): def __init__(self): super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0) - device = torch_xla.device() + device = torch.device('xla') module = Emb() module.to(device) @@ -255,7 +255,7 @@ def test_inputs_not_computed(self): def foo(x): return x * 2 - device = torch_xla.device() + device = torch.device('xla') x = torch.rand(5, device=device) x = x.unsqueeze(dim=-1) self._compile_and_check(foo, (x,)) @@ -265,7 +265,7 @@ def test_factory_copy(self): def foo(device): return torch.arange(5, device="cpu").to(device) - self._compile_and_check(foo, (torch_xla.device(),)) + self._compile_and_check(foo, (torch.device('xla'),)) def test_index_flag_unsupported(self): # The indices of the index operation are represented as @@ -277,7 +277,7 @@ def test_index_flag_unsupported(self): def foo(xt, t): return xt[t] - device = torch_xla.device() + device = torch.device('xla') xt = torch.rand(5, device=device) t = torch.randint(0, 5, (3,)) self._compile_and_check(foo, (xt, t)) @@ -299,7 +299,7 @@ def test_cpu_flag_unsupported(self): def foo(t): return t.cpu() - device = torch_xla.device() + device = torch.device('xla') t = torch.randint(0, 5, (3,), device=device) self._compile_and_check(foo, (t,)) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 572d255514a6..d7c55c1c6405 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -89,7 +89,7 @@ def test_sync_after_dynamo(self): head_dim = 128 running = 16 - device = torch_xla.device() + device = torch.device('xla') cache = torch.rand((cache_len, kv_heads, head_dim)).to(device) update_indices = torch.randint( 0, cache_len, (running,), dtype=torch.long).to(device) @@ -116,7 +116,7 @@ def copy_a_to_b(a): copy = torch.ops.aten.copy_.default(a, res) return copy - device = torch_xla.device() + device = torch.device('xla') compiled_copy = torch.compile(copy_a_to_b, backend=backend) a = torch.randn(2, 9).to(device) res = compiled_copy(a) @@ -150,7 +150,7 @@ def fn_simple(self, x, y): def _choose_proper_device(self, 
initialize_on_cuda): if not initialize_on_cuda: - return torch_xla.device() + return torch.device('xla') assert initialize_on_cuda if xr.device_type() != "CUDA" or not torch.cuda.is_available(): @@ -164,7 +164,7 @@ def _choose_proper_device(self, initialize_on_cuda): @skipOnNeuron def test_simple_model(self): - device = torch_xla.device() + device = torch.device('xla') x = torch.tensor(100.0) y = torch.tensor(200.0) xla_x = x.to(device) @@ -413,7 +413,7 @@ def test_resnet18(self, initialize_on_cuda, backend): @skipOnNeuron def test_resnet18_lazy_vs_dynamo(self): sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) - device = torch_xla.device() + device = torch.device('xla') loader = self.get_loader(device, sample_count) resnet18_base = torchvision.models.resnet18() resnet18_base.eval() @@ -448,7 +448,7 @@ def fn_fallback(t): torch._dynamo.reset() met.clear_all() - device = torch_xla.device() + device = torch.device('xla') # Initial tracing dynamo_fn = torch.compile(fn_fallback, backend="openxla") @@ -488,7 +488,7 @@ def fn_fallback(t): torch._dynamo.reset() met.clear_all() - device = torch_xla.device() + device = torch.device('xla') # Initial tracing dynamo_fn = torch.compile(fn_fallback, backend="openxla") @@ -541,7 +541,7 @@ def train_model(self, model, data, target): def test_simple_model(self): torch._dynamo.reset() - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(3, 5, requires_grad=True) xla_input = input.detach().to(device) xla_input.requires_grad = True @@ -577,7 +577,7 @@ def test_simple_model(self): def test_resnet18(self): torch._dynamo.reset() met.clear_counters() - device = torch_xla.device() + device = torch.device('xla') batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4) sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = xu.SampleGenerator( @@ -650,7 +650,7 @@ def train_model(self, model, data, target, optimizer): def test_simple_model(self): torch._dynamo.reset() - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(3, 5, requires_grad=True) saved_input = input.detach().to(device).cpu() xla_input = input.detach().to(device) @@ -673,7 +673,7 @@ def test_simple_model(self): def test_resnet18(self): torch._dynamo.reset() met.clear_counters() - device = torch_xla.device() + device = torch.device('xla') batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4) sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = xu.SampleGenerator( @@ -732,7 +732,7 @@ def test_resnet18(self): class DynamoErrorMessageTest(parameterized.TestCase): def test_mixed_cpu_tensor(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(4, 3, 224, 224) input_xla = input.clone().to(device) resnet18 = torchvision.models.resnet18() diff --git a/test/dynamo/test_dynamo_aliasing.py b/test/dynamo/test_dynamo_aliasing.py index 36bfb5744bd4..709186bec02c 100644 --- a/test/dynamo/test_dynamo_aliasing.py +++ b/test/dynamo/test_dynamo_aliasing.py @@ -11,7 +11,7 @@ class TestBufferDonationUtil(unittest.TestCase): def test_hash_with_buffer_donor(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) res = torch.cos(input) hash_no_donor = torch_xla._XLAC._get_graph_hash([res]) @@ -40,7 +40,7 @@ def dummy_mul(self, input): return input * 1.1 def test_manual_buffer_donation(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_mul_compiled = 
torch.compile( @@ -55,7 +55,7 @@ def test_manual_buffer_donation(self): torch.allclose(input_cloned.cpu() * 1.1, input.cpu()) def test_manual_buffer_donation_for_non_inplce_op(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_mul_compiled = torch.compile(self.dummy_mul, backend='openxla') @@ -81,7 +81,7 @@ def dummy_inplace(input): torch.ops.xla.dynamo_set_buffer_donor_(input, True) input += (0.5 * torch.sin(input)) - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla') @@ -109,7 +109,7 @@ def dummy_add(self, input): return input + 1 def test_manual_buffer_donation(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile( @@ -127,7 +127,7 @@ def test_manual_buffer_donation(self): self.assertFalse(torch_xla._XLAC._get_buffer_donation(input)) def test_manual_buffer_donation_for_non_inplce_op(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla') @@ -152,7 +152,7 @@ def test_manual_buffer_donation_for_inplce_op_repeat(self): def dummy_inplace(input): input += (0.3 * torch.cos(input)) - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla') @@ -174,7 +174,7 @@ def dummy_inplace(input): self.assertEqual(met.metric_data('CompileTime')[0], 1) def test_buffer_donation_on_non_data_tensor(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) res = input + 1 diff --git a/test/dynamo/test_dynamo_config.py b/test/dynamo/test_dynamo_config.py index 66f21cc84e91..d67c32acd475 100644 --- a/test/dynamo/test_dynamo_config.py +++ b/test/dynamo/test_dynamo_config.py @@ -10,7 +10,7 @@ def dummy_test(self, a): return a.cos().sin() def test_config_skip_input_data_check(self): - device = torch_xla.device() + device = torch.device('xla') print(config.skip_input_data_check) config.skip_input_data_check = True compiled_dummy = torch.compile(self.dummy_test, backend="openxla") diff --git a/test/dynamo/test_dynamo_dynamic_shape.py b/test/dynamo/test_dynamo_dynamic_shape.py index 1aa6905261f7..b475cc4fa904 100644 --- a/test/dynamo/test_dynamo_dynamic_shape.py +++ b/test/dynamo/test_dynamo_dynamic_shape.py @@ -45,7 +45,7 @@ def _get_linear_and_input(self, in_dim: int, out_dum: int, batch_dim: int, def test_dynamic_shape_basic(self): torch_xla.manual_seed(100) - device = torch_xla.device() + device = torch.device('xla') # model setup dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input( 10, 20, 20, device) @@ -78,7 +78,7 @@ def test_dynamic_shape_basic(self): def test_dynamic_shape_basic_with_mark_dynamic(self): torch_xla.manual_seed(100) - device = torch_xla.device() + device = torch.device('xla') # model setup dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input( 10, 40, 40, device) @@ -123,7 +123,7 @@ def test_dynamic_shape_basic_with_mark_dynamic(self): def test_dynamic_shape_multiple_batchs(self): torch_xla.manual_seed(100) - 
device = torch_xla.device() + device = torch.device('xla') # model setup in_dim = 16 out_dum = 32 @@ -180,7 +180,7 @@ def test_dynamic_shape_multiple_batchs(self): def test_dynamic_shape_mix_with_non_dynamic(self): torch_xla.manual_seed(100) - device = torch_xla.device() + device = torch.device('xla') # model setup in_dim = 15 out_dum = 31 @@ -238,7 +238,7 @@ def test_dynamic_shape_mix_with_non_dynamic(self): self.assertEqual(met.metric_data('ExecuteTime')[0], 1) def test_dynamic_decoder(self): - device = torch_xla.device() + device = torch.device('xla') config = DecoderOnlyConfig() config.num_hidden_layers = 2 config.hidden_size = 512 @@ -257,7 +257,7 @@ def test_dynamic_decoder(self): self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 2) def test_dynamic_shape_decoder_mark_dynamic(self): - device = torch_xla.device() + device = torch.device('xla') config = DecoderOnlyConfig() config.num_hidden_layers = 2 config.hidden_size = 512 @@ -276,7 +276,7 @@ def test_dynamic_shape_decoder_mark_dynamic(self): self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 2) def test_dynamic_shape_no_retracing(self): - device = torch_xla.device() + device = torch.device('xla') # model setup _, dummy_linear_xla, _, input_xla = self._get_linear_and_input( 8, 10, 20, device) @@ -295,7 +295,7 @@ def test_dynamic_shape_no_retracing(self): "Skip right now because with torch._dynamo.config.inline_inbuilt_nn_modules = True, dynamic compiles takes minutes for resnet18." ) def test_dynamic_shape_resnet18(self): - device = torch_xla.device() + device = torch.device('xla') sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = self._get_loader(device, sample_count, batch_size=4) diff --git a/test/dynamo/test_dynamo_graph_dump.py b/test/dynamo/test_dynamo_graph_dump.py index ae0383a47963..5b35221fcea2 100644 --- a/test/dynamo/test_dynamo_graph_dump.py +++ b/test/dynamo/test_dynamo_graph_dump.py @@ -27,7 +27,7 @@ def test_dump_graph_with_dynamo_execution(self): if not save_file: assert False, "This test should be run with XLA_SAVE_TENSORS_FILE" save_file += '.0' - device = torch_xla.device() + device = torch.device('xla') xla_x = torch.tensor(100.0).to(device) xla_y = torch.tensor(200.0).to(device) res_xla_dynamo = self.fn_simple_dynamo(xla_x, xla_y) diff --git a/test/dynamo/test_dynamo_integrations_util.py b/test/dynamo/test_dynamo_integrations_util.py index 293bef17ec05..04d1615817d4 100644 --- a/test/dynamo/test_dynamo_integrations_util.py +++ b/test/dynamo/test_dynamo_integrations_util.py @@ -20,7 +20,7 @@ class PybindTest(unittest.TestCase): def test_get_tensors_xla_device_data_node(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.randn(20, 5).to(xla_device) t2 = torch.randn(20, 5).to(xla_device) t3 = t2 + t1 @@ -42,7 +42,7 @@ def test_get_tensors_xla_device_data_node(self): assert (expected_tensor_ids == sorted(res_pair[0])) def test_get_base_seed_as_tensor(self): - device = torch_xla.device() + device = torch.device('xla') xm.set_rng_state(23, str(device)) base_seed = torch_xla._XLAC._get_base_seed_as_tensor(str(device)).item() self.assertEqual(23, base_seed) @@ -51,7 +51,7 @@ def test_get_seed_info_id(self): self.assertEqual(torch_xla._XLAC._get_seed_info_id(), -127389) def test_check_tensor_need_materialization(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.randn(20, 5) assert (torch_xla._XLAC._check_tensor_need_materialization([t1]) == [False]) t1 = t1.to(xla_device) @@ -67,7 +67,7 @@ def 
test_check_tensor_need_materialization(self): assert (torch_xla._XLAC._check_tensor_need_materialization([t1]) == [True]) def test_get_graph_hash(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_input = torch.randn(64, 256, 14, 14).to(xla_device) xla_dummy_model = dummy_model.to(xla_device) xla_out = xla_dummy_model(xla_input) @@ -85,7 +85,7 @@ def test_get_graph_hash(self): assert (hash == torch_xla._XLAC._get_graph_hash([xla_out_2])) def test_clear_pending_irs(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') torch_xla.sync() t1 = torch.randn(20, 5).to(xla_device) t2 = torch.randn(20, 5).to(xla_device) @@ -104,7 +104,7 @@ def test_clear_pending_irs(self): self.assertEqual(met.metric_data('ExecuteTime')[0], 1) def test_run_cached_graph(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_input = torch.randn(64, 256, 14, 14).to(xla_device) xla_dummy_model = dummy_model.to(xla_device) xla_out = xla_dummy_model(xla_input) diff --git a/test/dynamo/test_graph_input_matcher.py b/test/dynamo/test_graph_input_matcher.py index 70dd0be73f57..7a03139ce029 100644 --- a/test/dynamo/test_graph_input_matcher.py +++ b/test/dynamo/test_graph_input_matcher.py @@ -24,7 +24,7 @@ def get_example_inputs(self): class TestGraphInputMatcher(unittest.TestCase): def test_no_cache_fx_gragh_inputs(self): - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') model = M().to(device=xla_dev) inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs()) diff --git a/test/dynamo/test_num_output.py b/test/dynamo/test_num_output.py index b540e0691643..77081e3f2c5e 100644 --- a/test/dynamo/test_num_output.py +++ b/test/dynamo/test_num_output.py @@ -59,7 +59,7 @@ def get_example_inputs(self): class TestNumOutput(unittest.TestCase): def do_test(self, model_class, expected_num_output): - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') model = model_class().to(device=xla_dev) inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs()) diff --git a/test/dynamo/test_traceable_collectives.py b/test/dynamo/test_traceable_collectives.py index 45bd89266604..58cdd092cd61 100644 --- a/test/dynamo/test_traceable_collectives.py +++ b/test/dynamo/test_traceable_collectives.py @@ -18,7 +18,7 @@ def collective_broadcast_and_cos(input, src): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): print(f'skip this test for hw {xm.xla_device_hw(device)}') diff --git a/test/eager/test_eager.py b/test/eager/test_eager.py index 552382a2dc39..48acb0958ed4 100644 --- a/test/eager/test_eager.py +++ b/test/eager/test_eager.py @@ -20,7 +20,7 @@ def test_eager_basic(self): xm.wait_device_ops() met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # For some reason randn will also trigger an execution of # size [5, 5] full of 0. 
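The eager-mode tests above and below all exercise the same user-facing pattern; as a rough sketch (the tensor shapes and the matmul are arbitrary), that pattern is:

``` python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Run ops eagerly instead of accumulating them into a lazy graph.
torch_xla.experimental.eager_mode(True)

device = torch.device('xla')
t1 = torch.randn(5, 5, device=device)  # executes immediately on the XLA device
t2 = t1 @ t1.t()                       # also executes eagerly
xm.wait_device_ops()                   # block until all queued device work finishes
print(t2.cpu())
```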
@@ -36,7 +36,7 @@ def test_eager_basic(self): def test_eager_recompile(self): self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5, device=device) xm.wait_device_ops() @@ -55,7 +55,7 @@ def test_eager_recompile(self): def test_eager_in_place(self): self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5, device=device) xm.wait_device_ops() @@ -67,7 +67,7 @@ def test_eager_in_place(self): def test_eager_random_seed(self): self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') met.clear_all() t1 = torch.randn(12, 13, device=device) @@ -82,7 +82,7 @@ def test_eager_random_seed(self): def test_eager_set_random_seed(self): self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') old_seed = 1234 xm.set_rng_state(old_seed) @@ -95,7 +95,7 @@ def test_eager_set_random_seed(self): def test_batch_norm_execute_once(self): xm.wait_device_ops() - device = torch_xla.device() + device = torch.device('xla') m = nn.BatchNorm2d(16).to(device) m.train() input = torch.randn(8, 16, 8, 32).to(device) @@ -112,7 +112,7 @@ def test_batch_norm_execute_once(self): torch_xla._XLAC._get_xla_tensor_debug_info(m.running_mean)) def test_svd_execute_once(self): - device = torch_xla.device() + device = torch.device('xla') a = torch.randn(5, 3).to(device) xm.wait_device_ops() met.clear_all() diff --git a/test/eager/test_eager_all_reduce_in_place.py b/test/eager/test_eager_all_reduce_in_place.py index 7ea68b7fb6e4..5349212bea0c 100644 --- a/test/eager/test_eager_all_reduce_in_place.py +++ b/test/eager/test_eager_all_reduce_in_place.py @@ -10,7 +10,7 @@ def _mp_fn(index): import torch_xla torch_xla.experimental.eager_mode(True) - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): return diff --git a/test/eager/test_eager_spmd.py b/test/eager/test_eager_spmd.py index 3b05ba7af652..36a8faa931f5 100644 --- a/test/eager/test_eager_spmd.py +++ b/test/eager/test_eager_spmd.py @@ -39,7 +39,7 @@ def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None): return xs.Mesh(device_ids, mesh_shape, axis_names) def test_eager_spmd_basic(self): - device = torch_xla.device() + device = torch.device('xla') mesh = self._get_mesh((self.n_devices,), axis_names=('data',)) torch.manual_seed(100) linear = torch.nn.Linear(10, 20) @@ -52,7 +52,7 @@ def test_eager_spmd_basic(self): self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2)) def test_module_to_empty_sharding(self): - device = torch_xla.device() + device = torch.device('xla') mlinear = MultiLinear() mlinear.to(device) torch_xla._XLAC._get_xla_sharding_spec(mlinear.linear1.weight) diff --git a/test/eager/test_eager_with_torch_compile.py b/test/eager/test_eager_with_torch_compile.py index e7604658aa5e..c66fbda1bbc6 100644 --- a/test/eager/test_eager_with_torch_compile.py +++ b/test/eager/test_eager_with_torch_compile.py @@ -19,7 +19,7 @@ def dummy_cos_sin(self, tensor): def test_eager_with_compile_basic(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # this part happens eagerly t1 = torch.randn(5, 5, device=device) @@ -38,7 +38,7 @@ def test_eager_with_compile_basic(self): def test_eager_execute_compiled_multiple_times(self): 
met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # this part happens eagerly t1 = torch.randn(10, 5, device=device) t1.add_(0.5) diff --git a/test/eager/test_eager_with_xla_compile.py b/test/eager/test_eager_with_xla_compile.py index 5aee35b2a12d..3d4b88ce0dfa 100644 --- a/test/eager/test_eager_with_xla_compile.py +++ b/test/eager/test_eager_with_xla_compile.py @@ -29,7 +29,7 @@ def dummy_graph_break(self, t): def test_eager_with_compile_basic(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # this part happens eagerly t1 = torch.randn(5, 5, device=device) @@ -54,7 +54,7 @@ def test_eager_with_compile_basic(self): def test_eager_execute_compiled_multiple_times(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # this part happens eagerly t1 = torch.randn(10, 5, device=device) t1.add_(0.5) @@ -69,7 +69,7 @@ def test_eager_execute_compiled_multiple_times(self): def test_eager_with_compile_graph_break(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5, device=device) with self.assertRaisesRegex( diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 28b0b7709060..1c7176df8013 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -17,7 +17,7 @@ class TestXMCollectiveOpsTpu(parameterized.TestCase): @staticmethod def _broadcast(sync): torch.manual_seed(xr.global_ordinal()) - device = torch_xla.device() + device = torch.device('xla') model = nn.Linear(5, 5).to(device) if sync: xm.broadcast_master_param(model) @@ -41,7 +41,7 @@ def test_broadcast_master_param(self, sync): @staticmethod def _all_reduce(pin_layout): - device = torch_xla.device() + device = torch.device('xla') # Prevent 0 and 1 from being converted to constants ordinal = xm.send_cpu_data_to_device( torch.tensor( @@ -63,7 +63,7 @@ def test_all_reduce(self, pin_layout): @staticmethod def _all_gather(pin_layout): - device = torch_xla.device() + device = torch.device('xla') ordinal = torch.tensor([xr.global_ordinal()], device=device) out = xm.all_gather(ordinal, pin_layout=pin_layout) torch_xla.sync() @@ -80,7 +80,7 @@ def test_all_gather(self, pin_layout): @staticmethod def _reduce_scatter(pin_layout): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() tensor = -torch.arange(world_size, dtype=torch.float32).to(device) @@ -105,7 +105,7 @@ def test_reduce_scatter(self, pin_layout): @staticmethod def _all_to_all(pin_layout): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() tensor = torch.cat( @@ -151,7 +151,7 @@ def callable(input): return input dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') input = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device) @@ -175,7 +175,7 @@ def callable(output, input): return output dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') input = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device) @@ -200,7 +200,7 @@ def callable(output, input): def _all_gather(use_dynamo: bool): met.clear_all() dist.init_process_group("xla", 
init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') def callable(input): output_tensor = [ @@ -229,7 +229,7 @@ def callable(input): def _reduce_scatter(use_dynamo: bool): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') def callable(output, input): dist.reduce_scatter_tensor(output, input) @@ -254,7 +254,7 @@ def callable(output, input): def _all_to_all_single(use_dynamo: bool, split_size: int = 1): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') def callable(output, input): dist.all_to_all_single(output, input) diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py index d93bbe45c4d9..62e24b804af2 100644 --- a/test/pjrt/test_ddp.py +++ b/test/pjrt/test_ddp.py @@ -26,7 +26,7 @@ class TestPjRtDistributedDataParallel(parameterized.TestCase): @staticmethod def _ddp_init(index: int = ...): dist.init_process_group('xla', init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') model = nn.Linear(10, 10).to(device) ddp_model = DDP(model) diff --git a/test/pjrt/test_profiler.py b/test/pjrt/test_profiler.py index 17892261119a..f9189aa50342 100644 --- a/test/pjrt/test_profiler.py +++ b/test/pjrt/test_profiler.py @@ -33,12 +33,12 @@ class TestPjRtProfiler(absltest.TestCase): def setUp(self): # HACK: ensure libtpu is loaded if using TPU - torch_xla.device() + torch.device('xla') def test_profiler_output(self): tempdir = self.create_tempdir().full_path - device = torch_xla.device() + device = torch.device('xla') ones = torch.ones([5]) with _profile(tempdir): xones = ones.to(device) diff --git a/test/pjrt/test_runtime_multi_cpu.py b/test/pjrt/test_runtime_multi_cpu.py index 25c3280ce4b5..71f667765637 100644 --- a/test/pjrt/test_runtime_multi_cpu.py +++ b/test/pjrt/test_runtime_multi_cpu.py @@ -27,7 +27,7 @@ def test_default_cpu_device(self): os.environ.pop(xenv.PJRT_CPU_ASYNC_CLIENT, None) expected = {0: torch.device('xla:0')} - devices_per_process = pjrt.run_multiprocess(xm.xla_device) + devices_per_process = pjrt.run_multiprocess(torch_xla.device) self.assertDictEqual(devices_per_process, expected) def test_multi_cpu_devices(self): @@ -38,7 +38,7 @@ def test_multi_cpu_devices(self): 3: torch.device('xla:3'), } - devices_per_process = pjrt.run_multiprocess(xm.xla_device) + devices_per_process = pjrt.run_multiprocess(torch_xla.device) self.assertDictEqual(devices_per_process, expected) def test_global_ordinal(self): diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 89ad676ca383..aa039166ae67 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -172,7 +172,7 @@ def test_local_ordinal_with_discontiguous_global_ordinal_v4_threaded(self): @staticmethod def _spawn_threads() -> Dict[int, torch.device]: results = {} - pjrt.spawn_threads(lambda i: results.setdefault(i, torch_xla.device())) + pjrt.spawn_threads(lambda i: results.setdefault(i, torch.device('xla'))) return results @@ -187,7 +187,7 @@ def test_spawn_threads(self): @staticmethod def _spawn_error(): # Initialize the client in the parent process - torch_xla.device() + torch.device('xla') torch_xla.launch(xm.xla_device) @@ -199,7 +199,7 @@ def test_spawn_error(self): @staticmethod def _runtime_device_attributes(): - return xr.runtime_device_attributes(str(torch_xla.device())) + return xr.runtime_device_attributes(str(torch.device('xla'))) def 
test_runtime_device_attributes(self): result = pjrt.run_multiprocess(self._runtime_device_attributes) @@ -226,7 +226,7 @@ def test_global_runtime_device_attributes(self): @staticmethod def _execute_time_metric(): # Initialize the client before starting the timer. - torch_xla.device() + torch.device('xla') begin = time.perf_counter_ns() value = ( diff --git a/test/pjrt/test_train_hf_transformer.py b/test/pjrt/test_train_hf_transformer.py index d484edc0a6ce..93d932bab7a8 100644 --- a/test/pjrt/test_train_hf_transformer.py +++ b/test/pjrt/test_train_hf_transformer.py @@ -55,7 +55,7 @@ def finetune(rank, train_dataset, test_dataset, tokenizer, flags): drop_last=True, generator=rng) - device = torch_xla.device() + device = torch.device('xla') model = AutoModelForSequenceClassification.from_pretrained( 'google-bert/bert-base-cased', num_labels=5) model.to(device) diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 3355f8efba99..bb3b7b8114c4 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -559,7 +559,7 @@ def _alt_lookup(d, keys, defval): def instantiate_test(cls, name, test, *, generic_cls): test_name = name + '_' + cls.device_type class_name = cls.__name__ - real_device_type = xm.xla_device_hw(str(torch_xla.device())) + real_device_type = xm.xla_device_hw(str(torch.device('xla:0'))) assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type] @@ -631,7 +631,7 @@ def get_primary_device(cls): @classmethod def setUpClass(cls): - # Sets the primary test device to the xla_device (CPU or TPU) + # Sets the primary test device to the torch_xla.device (CPU or TPU) cls.primary_device = str(torch_xla.device()) torch_xla._XLAC._xla_set_mat_mul_precision('highest') diff --git a/test/quantized_ops/test_dot_general.py b/test/quantized_ops/test_dot_general.py index 71a39ff56e96..846da4f0255a 100644 --- a/test/quantized_ops/test_dot_general.py +++ b/test/quantized_ops/test_dot_general.py @@ -5,7 +5,7 @@ import torch_xla import unittest -device = torch_xla.device() +device = torch.device('xla') torch.manual_seed(12345) diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py index 88a34c69a4ae..ace38bfee083 100644 --- a/test/quantized_ops/test_quantized_matmul.py +++ b/test/quantized_ops/test_quantized_matmul.py @@ -12,7 +12,7 @@ torch.manual_seed(123456) -device = torch_xla.device() +device = torch.device('xla') class M(torch.nn.Module): diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py index 42c362ee8769..fbf7d5a4cded 100644 --- a/test/scan/test_scan.py +++ b/test/scan/test_scan.py @@ -45,7 +45,7 @@ class TestBase(XlaTestCase): def setUp(self): super().setUp() - self.device = torch_xla.device() + self.device = torch.device('xla') # Clear the scan computation cache before each test to avoid cross-test contamination. scan_module._SCAN_COMPUTATION_CACHE.clear() @@ -288,7 +288,7 @@ def test_scan_external_in_place_mutation(self): giving wrong results. """ # TODO(yifeit): Modify this test when external in-place mutation is eventually supported. 
- weird_global = torch.tensor([0.0, 0.0], device=torch_xla.device()) + weird_global = torch.tensor([0.0, 0.0], device='xla') def step_fn(carry, x): new_carry = carry + x @@ -296,9 +296,8 @@ def step_fn(carry, x): y = new_carry + weird_global return new_carry, y - init = torch.tensor([0.0, 0.0], device=torch_xla.device()) - xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], - device=torch_xla.device()) + init = torch.tensor([0.0, 0.0], device='xla') + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla') with self.assertRaisesRegex(AssertionError, "FakeTensor"): scan(step_fn, init, xs) @@ -371,9 +370,8 @@ def step_fn(carry, x): y = new_carry + torch.rand(2, device=torch_xla.device()) return new_carry, y - init = torch.tensor([0.0, 0.0], device=torch_xla.device()) - xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], - device=torch_xla.device()) + init = torch.tensor([0.0, 0.0], device='xla') + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla') _, ys = scan(step_fn, init, xs) # ys should be a 2D tensor with this shape. self.assertEqual(ys.shape, (3, 2)) diff --git a/test/scan/test_scan_layers.py b/test/scan/test_scan_layers.py index ba193ea1eb30..e093e83ccace 100644 --- a/test/scan/test_scan_layers.py +++ b/test/scan/test_scan_layers.py @@ -26,7 +26,7 @@ class ScanLayersTest(XlaTestCase): def setUp(self): super().setUp() - self.device = torch_xla.device() + self.device = torch.device('xla') def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor): assert a is not b, f"Expected {a} and {b} to be different tensors" diff --git a/test/scan/test_scan_pallas.py b/test/scan/test_scan_pallas.py index a267886cd3f7..6f77d1b52aa7 100644 --- a/test/scan/test_scan_pallas.py +++ b/test/scan/test_scan_pallas.py @@ -72,7 +72,7 @@ def fake_fa_wrapper(self, has_model_weight, use_scan): torch.manual_seed(12) torch_xla.manual_seed(12) hidden_states = torch.randn((8, 4, 256, 256)).requires_grad_().to('xla') - with torch_xla.device(): + with torch.device('xla'): attention_layers = AttentionLayers( has_model_weight, num_layer=3, use_scan=use_scan) hidden_states.retain_grad() diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py index 9bf081527c72..2bd1428a842c 100644 --- a/test/scan/test_scan_spmd.py +++ b/test/scan/test_scan_spmd.py @@ -23,7 +23,7 @@ def setUp(self): # Set up a simple SPMD mesh for these tests. self.spmd_mesh = get_1d_mesh(axis_name="model") set_global_mesh(self.spmd_mesh) - self.device = torch_xla.device() + self.device = torch.device('xla') @unittest.skipUnless(xr.global_runtime_device_count() >= 4, "Multiple devices required") diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 518e4203b459..71b109b90d12 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -42,7 +42,7 @@ def setUpClass(cls): super().setUpClass() def test_dynamo_spmd_basic(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -58,7 +58,7 @@ def test_dynamo_spmd_basic(self): # a ExecuteMetric. 
def test_dynamo_spmd_output_sharding_spec(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -74,7 +74,7 @@ def test_dynamo_spmd_output_sharding_spec(self): ) def test_dynamo_spmd_output_sharding_cache(self): met.clear_all() - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -90,7 +90,7 @@ def test_dynamo_spmd_output_sharding_cache(self): self.assertEqual(met.counter_value('UncachedOutputSharding'), 1) def test_dynamo_sharded_input(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -103,7 +103,7 @@ def test_dynamo_sharded_input(self): torch.allclose(xla_res.cpu(), dynamo_res.cpu()) def test_dynamo_input_sharding_changed(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -142,7 +142,7 @@ def test_dynamo_input_sharding_changed(self): @unittest.skipIf(xr.global_runtime_device_count() == 1, "Multiple devices needed to test the mesh change") def test_dynamo_input_sharding_threashold(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -183,7 +183,7 @@ def test_dynamo_input_sharding_threashold(self): del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -202,7 +202,7 @@ def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): torch.allclose(xla_res.cpu(), dynamo_res.cpu()) def test_dynamo_spmd_activation_sharding_with_dynamo_mark_sharding(self): - device = torch_xla.device() + device = torch.device('xla') mesh = self._get_mesh((1, self.n_devices)) device_ids = mesh.device_ids.tolist() mesh_shape = list(mesh.mesh_shape) diff --git a/test/spmd/test_mp_input_sharding.py b/test/spmd/test_mp_input_sharding.py index dc1e4aba12b0..135215e5b72c 100644 --- a/test/spmd/test_mp_input_sharding.py +++ b/test/spmd/test_mp_input_sharding.py @@ -34,7 +34,7 @@ def __next__(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_multiple_inputs(self): - device = torch_xla.device() + device = torch.device('xla') batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -61,7 +61,7 @@ def test_multiple_inputs(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_single_tensor(self): - device = torch_xla.device() + device = torch.device('xla') batch = torch.randn((16, 128)) train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -78,7 +78,7 @@ def test_single_tensor(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_error_single_tensor_with_input_sharding_dict(self): - device = torch_xla.device() + device = torch.device('xla') batch = torch.randn((16, 128)) train_loader = 
self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -95,7 +95,7 @@ def test_error_single_tensor_with_input_sharding_dict(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_input_sharding_none(self): - device = torch_xla.device() + device = torch.device('xla') batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -112,7 +112,7 @@ def test_input_sharding_none(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_error_missing_keys(self): - device = torch_xla.device() + device = torch.device('xla') batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) mesh = xs.get_1d_mesh('x') @@ -127,7 +127,7 @@ def test_error_missing_keys(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_input_sharding_not_dict(self): - device = torch_xla.device() + device = torch.device('xla') num_devices = xr.global_runtime_device_count() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))} train_loader = self.fake_dataloader(batch) diff --git a/test/spmd/test_sharding_strategies.py b/test/spmd/test_sharding_strategies.py index 2dd09580a5a6..8f31fc3dde51 100644 --- a/test/spmd/test_sharding_strategies.py +++ b/test/spmd/test_sharding_strategies.py @@ -146,7 +146,7 @@ def training_step(data): torch.manual_seed(42) tries = 5 -device = torch_xla.device() +device = torch.device('xla') if args.profile: print("Profiler server started at port 9012") server = xp.start_server(9012) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index 34221d375e9c..def91adef995 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -209,7 +209,7 @@ def test_single_host_replicated_tpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_debugging_spmd_single_host_tiled_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = torch_xla.device() + device = torch.device('xla') num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) @@ -252,7 +252,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_partial_replication_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = torch_xla.device() + device = torch.device('xla') num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) @@ -295,7 +295,7 @@ def test_single_host_partial_replication_cpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_replicated_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = torch_xla.device() + device = torch.device('xla') num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py index 45af3b154934..a0c7011f914f 100644 --- a/test/spmd/test_spmd_graph_dump.py +++ b/test/spmd/test_spmd_graph_dump.py @@ -26,7 +26,7 @@ def test_dump_with_output_sharding(self): assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE" should_dump_output_sharding = (save_format == 
'hlo') save_file += '.0' - device = torch_xla.device() + device = torch.device('xla') xla_x = torch.randn(8, 32).to(device) xla_y = torch.randn(8, 32).to(device) # shard one of the input tensor diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py index 9bc80194318f..6f6307ab0676 100644 --- a/test/spmd/test_spmd_lowering_context.py +++ b/test/spmd/test_spmd_lowering_context.py @@ -38,7 +38,7 @@ def test_basic(self): mesh_shape = (data_axis, model_axis) spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) - device = torch_xla.device() + device = torch.device('xla') a = torch.zeros(2048, device=device, requires_grad=True) xs.mark_sharding(a, spmd_mesh, ('x',)) b = torch.randn([32, 2048], device=device, requires_grad=True) @@ -108,7 +108,7 @@ def test_device_parameter_id_tensor_mapping(self): mesh_shape = (data_axis, model_axis) spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) - device = torch_xla.device() + device = torch.device('xla') a = torch.randn([32, 2048]).to(device) xs.mark_sharding(a, spmd_mesh, ('x', 'y')) b = torch.ones(2048).to(device) diff --git a/test/spmd/test_spmd_parameter_wrapping.py b/test/spmd/test_spmd_parameter_wrapping.py index 47f1bab8d33f..94267fab2b32 100644 --- a/test/spmd/test_spmd_parameter_wrapping.py +++ b/test/spmd/test_spmd_parameter_wrapping.py @@ -38,7 +38,7 @@ def setUpClass(cls): super().setUpClass() def test_fsdpv2(self): - device = torch_xla.device() + device = torch.device('xla') one_d_mesh = xs.get_1d_mesh("fsdp") xs.set_global_mesh(one_d_mesh) linears = MultiLinear() @@ -56,7 +56,7 @@ def test_fsdpv2(self): self.assertEqual(output.shape, torch.Size([100, 40])) def basic_spmd_test(self): - device = torch_xla.device() + device = torch.device('xla') one_d_mesh = xs.get_1d_mesh("data") input = torch.randn(8, 128) input2 = torch.randn(8, 128) diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index 006db02dd46e..935470d82446 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -206,7 +206,7 @@ def train_imagenet(): torch.manual_seed(42) - device = torch_xla.device() + device = torch.device('xla') model = get_model_property('model_fn')().to(device) if FLAGS.use_gradient_checkpointing: diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index d3fa093e8b13..4223b7b82ee5 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -76,7 +76,7 @@ def _assert_same_state_dict(self, sd1, sd2, keypath=""): if isinstance(sd1, torch.Tensor): assert sd1.device == sd2.device, f"Tensors on different devices at {keypath}: {sd1} vs {sd2}" - if sd1.device == torch_xla.device(): + if sd1.device == torch.device('xla'): sharding1 = torch_xla._XLAC._get_xla_sharding_spec(sd1) sharding2 = torch_xla._XLAC._get_xla_sharding_spec(sd2) assert sharding1 == sharding2, f"Different sharding on tensors at {keypath}: {sharding1} vs {sharding2}" diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7b1be7574a1f..455d9006078d 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -277,7 +277,7 @@ def test_mark_sharding_4d(self): self.assertTrue(torch.allclose(expected, actual)) def test_mark_sharding_not_ordered_sharding_spec_2d(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(8, 16, device='cpu') expected = t1 + t1 @@ -290,7 +290,7 @@ def 
test_mark_sharding_not_ordered_sharding_spec_2d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_not_ordered_sharding_spec_3d(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 8, 16, device='cpu') expected = t1 + t1 @@ -307,7 +307,7 @@ def test_mark_sharding_not_ordered_sharding_spec_3d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_not_ordered_sharding_spec_4d(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(32, 4, 8, 16, device='cpu') expected = t1 + t1 @@ -326,7 +326,7 @@ def test_mark_sharding_not_ordered_sharding_spec_4d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_partial(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 4).to(device) t2 = torch.randn(4, 4).to(device) # Somehow the eager cpu result is different from the xla result. @@ -356,7 +356,7 @@ def test_mark_sharding_partial(self): self.assertTrue(torch.allclose(expected, actual)) def test_propagate_replicated_sharding(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 4).to(device) t2 = torch.randn(4, 4).to(device) t3 = t1 @ t2 @@ -368,7 +368,7 @@ def test_propagate_replicated_sharding(self): self.assertIn("replicated", torch_xla._XLAC._get_xla_sharding_spec(t3)) def test_mark_sharding_partial_unordered(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 3, 4).to(device) t2 = torch.randn(4, 3, 4).to(device) expected = t1 + t2 @@ -467,7 +467,7 @@ def test_3d_tensor_2d_mesh(self): (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) def test_partial_replication_addmm(self): - device = torch_xla.device() + device = torch.device('xla') z_dim = 2 if self.n_devices >= 4 else 1 mesh = self._get_mesh((z_dim, self.n_devices // z_dim)) @@ -657,7 +657,7 @@ def test_send_cpu_data_to_device_with_sharding(self): sharding_spec = xs.ShardingSpec(mesh, (0, 1)) self.assertTrue(sharding_spec.can_apply(tensor)) xtensors = xm.send_cpu_data_to_device([tensor], - torch_xla.device(), + torch.device('xla'), input_sharding=sharding_spec) self.assertEqual(len(xtensors), 1) outbound = met.metric_data("OutboundData")[1] @@ -955,7 +955,7 @@ def test_named_partition_spec(self): self.assertTrue("replicated" in sharding_spec) def test_shard_device_data_ir(self): - device = torch_xla.device() + device = torch.device('xla') xla_x = torch.randn(8, 128, device=device) # xla_x now becomes a device data IR xla_y = xla_x * 5 @@ -967,7 +967,7 @@ def test_shard_device_data_ir(self): self.assertTrue(torch.allclose(xla_y.cpu(), xla_x.cpu() * 5)) def test_shard_device_data_ir_after_sync(self): - device = torch_xla.device() + device = torch.device('xla') xla_x = torch.randn(8, 128, device=device) x = xla_x.cpu() # xla_x now becomes a device data IR without XLAData @@ -1370,7 +1370,7 @@ def test_spmd_all_reduce_scale(self): self.assertTrue(torch.allclose(x.cpu(), expected_x)) def test_get_1d_mesh(self): - device = torch_xla.device() + device = torch.device('xla') mesh = xs.get_1d_mesh("data") t1 = torch.randn(8, 8).to(device) xt = xs.mark_sharding(t1, mesh, ("data", None)) @@ -1387,7 +1387,7 @@ def test_get_1d_mesh(self): xr.global_runtime_device_count() > 1, "Multiple devices required for dataloader sharding test") def test_data_loader_with_sharding(self): - device = torch_xla.device() + device = torch.device('xla') mesh = 
xs.get_1d_mesh("data") batch_size = 8 train_loader = xu.SampleGenerator( @@ -1410,7 +1410,7 @@ def test_data_loader_with_sharding(self): xr.global_runtime_device_count() > 1, "Multiple devices required for dataloader sharding test") def test_data_loader_with_non_batch_size(self): - device = torch_xla.device() + device = torch.device('xla') mesh = xs.get_1d_mesh("data") batch_size = mesh.size() - 1 train_loader = xu.SampleGenerator( @@ -1433,7 +1433,7 @@ def test_data_loader_with_non_batch_size(self): xr.global_runtime_device_count() > 1, "Multiple devices required for dataloader sharding test") def test_data_loader_with_non_batch_size_and_mini_batch(self): - device = torch_xla.device() + device = torch.device('xla') mesh = xs.get_1d_mesh("data") batch_size = mesh.size() - 1 train_loader = xu.SampleGenerator( @@ -1453,7 +1453,7 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self): data, _ = iter(train_device_loader).__next__() def test_fallback(self): - device = torch_xla.device() + device = torch.device('xla') theta: float = 10000 dim = 16 @@ -1487,7 +1487,7 @@ def test_xla_patched_linear(self): import torch_xla.core.xla_model as xm import torch.nn.functional as F - with torch_xla.device(): + with torch.device('xla'): torch_xla.manual_seed(42) x0 = torch.randn(2, 3, requires_grad=True) w0 = torch.randn(4, 3, requires_grad=True) diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index ba051964a108..741392f89562 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -42,18 +42,18 @@ def test_xla_device(self): self.assertEqual(device, torch.device('xla:0')) def test_xla_real_devices(self): - device = torch_xla.device() + device = torch.device('xla') device_type = os.environ['PJRT_DEVICE'] self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) def test_xla_device_hw(self): - device = torch_xla.device() + device = torch.device('xla') device_type = os.environ['PJRT_DEVICE'] replication_devices = xm.xla_replication_devices([device]) self.assertEqual(xm.xla_device_hw(device), device_type) def test_xla_replication_devices(self): - device = torch_xla.device() + device = torch.device('xla') device_type = os.environ['PJRT_DEVICE'] replication_devices = xm.xla_replication_devices([device]) self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) @@ -148,7 +148,7 @@ def setUpClass(cls): @unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'], f"TPU/GPU autocast test.") def test_xla_autocast_api(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.ones([2, 3], device=device, dtype=torch.float32) t2 = torch.ones([3, 2], device=device, dtype=torch.float32) with autocast(device, dtype=torch.bfloat16): diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 60c8b31a00e9..9c9f2b3a49ec 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -88,7 +88,7 @@ def test_non_tensor_scalar(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], - torch_xla.device(), + torch.device('xla'), input_sharding=sharding_spec)[0] # we will transfer 0.5 as a device_data to the 'SPMD:0' device, need to make sure # that virtual device can handle this case. 
@@ -101,7 +101,7 @@ def test_sync_on_virtual_device(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], - torch_xla.device(), + torch.device('xla'), input_sharding=sharding_spec)[0] xt2 = xt1 / 0.5 torch_xla.sync(wait=True) @@ -111,7 +111,7 @@ def test_sync_on_virtual_device(self): def test_virtual_device_no_upload(self): met.clear_all() - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) # t1's upload to device should be deferred @@ -125,7 +125,7 @@ def test_virtual_device_no_upload(self): def test_virtual_device_upload_after_mark_sharding(self): met.clear_all() partition_spec = (0, 1) - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(8, 8).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) @@ -139,7 +139,7 @@ def test_virtual_device_upload_after_mark_sharding(self): def test_virtual_device_upload_after_tracing(self): met.clear_all() - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(8, 8).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) @@ -152,7 +152,7 @@ def test_virtual_device_upload_after_tracing(self): def test_virtual_device_upload_for_sharded_dataloader(self): met.clear_counters() - device = torch_xla.device() + device = torch.device('xla') sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ t1 = xm.send_cpu_data_to_device([torch.randn(8, 8)], diff --git a/test/stablehlo/test_composite.py b/test/stablehlo/test_composite.py index 8fe211475ba1..8f95f5b73967 100644 --- a/test/stablehlo/test_composite.py +++ b/test/stablehlo/test_composite.py @@ -71,7 +71,7 @@ class XlaMarkPatternTest(unittest.TestCase): def run_func_get_stablehlo(self, f, input_args): - device = torch_xla.device() + device = torch.device('xla') input_args = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), input_args) exported = torch.export.export(AsModule(f), input_args) diff --git a/test/stablehlo/test_implicit_broadcasting.py b/test/stablehlo/test_implicit_broadcasting.py index 10fbe5789981..04c5dd882a7b 100644 --- a/test/stablehlo/test_implicit_broadcasting.py +++ b/test/stablehlo/test_implicit_broadcasting.py @@ -10,7 +10,7 @@ # The following tests cover the implcit-broadcasting for static and bounded # dynamic shapes. 
-device = torch_xla.device() +device = torch.device('xla') class ImplicitBroadcasting(unittest.TestCase): diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index 34426f978029..a5b32cd7d5bc 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -55,7 +55,7 @@ def count_qdq_ops(g: torch.fx.Graph): class PT2EExportTest(unittest.TestCase): def test_per_tensor_qdq(self): - device = torch_xla.device() + device = torch.device('xla') x = torch.randn(2, 3, 4, 5).to(device) x = torch.ops.quantized_decomposed.quantize_per_tensor( x, 0.4, 2, -128, 127, torch.int8) @@ -69,7 +69,7 @@ def test_per_tensor_qdq(self): self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) def test_per_channel_qdq(self): - device = torch_xla.device() + device = torch.device('xla') x = torch.randn(2, 3, 4, 5).to(device) scale = torch.tensor([3.2, 5.3, 0.1, 10]).to(device) zero_point = torch.tensor([1, 2, -1, -2], dtype=torch.int64).to(device) diff --git a/test/stablehlo/test_stablehlo_compile.py b/test/stablehlo/test_stablehlo_compile.py index a57faf7ff5f2..a5abc0d27498 100644 --- a/test/stablehlo/test_stablehlo_compile.py +++ b/test/stablehlo/test_stablehlo_compile.py @@ -21,7 +21,7 @@ def test_resnet18_stablehlo_compile(self): torch_input = torch.tensor(np_input).float() cpu_output = resnet18(torch_input) # Run ResNet on XLA device. - device = torch_xla.device() + device = torch.device('xla') # materalize the fake data for test purpose torch_xla.sync() xm.wait_device_ops() diff --git a/test/stablehlo/test_stablehlo_custom_call.py b/test/stablehlo/test_stablehlo_custom_call.py index a315bbc230db..3a73c93f1005 100644 --- a/test/stablehlo/test_stablehlo_custom_call.py +++ b/test/stablehlo/test_stablehlo_custom_call.py @@ -118,7 +118,7 @@ def forward(self, x): # self.assertTrue("api_version = 1" in shlo_text) def test_place_to_host_device(self): - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones(10, device=dev) b = place_to_host(a) shlo_text = xm.get_stablehlo([b]) @@ -137,7 +137,7 @@ def test_place_to_host_device(self): def test_place_to_host_device_autograd(self): # Test that gradient can flow through place_to_host and place_to_device ops. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones(10, device=dev, requires_grad=True) b = place_to_host(a) c = b.sum() @@ -155,7 +155,7 @@ def test_place_to_host_device_aot_autograd(self): # specifically `aot_function`. 
from functorch.compile import aot_function, make_boxed_func # type: ignore - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones(10, device=dev, requires_grad=True) def my_fn(x): diff --git a/test/stablehlo/test_stablehlo_inference.py b/test/stablehlo/test_stablehlo_inference.py index a29b66ebceaa..0e1b0ffdfe25 100644 --- a/test/stablehlo/test_stablehlo_inference.py +++ b/test/stablehlo/test_stablehlo_inference.py @@ -67,7 +67,7 @@ def forward(self, x, y): output = m(*data) exported = export_torch_model(m, data) - device = torch_xla.device() + device = torch.device('xla') data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) output2 = exported(*data).cpu() @@ -91,7 +91,7 @@ def forward(self, inputs): output = m(*data) exported = export_torch_model(m, data) - device = torch_xla.device() + device = torch.device('xla') data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) output2 = exported(*data) self.assertEqual(len(output2), 2) diff --git a/test/stablehlo/test_stablehlo_save_load.py b/test/stablehlo/test_stablehlo_save_load.py index 71ff463578cb..5d353e25a386 100644 --- a/test/stablehlo/test_stablehlo_save_load.py +++ b/test/stablehlo/test_stablehlo_save_load.py @@ -17,7 +17,7 @@ class StableHloDumpTest(unittest.TestCase): def test_simple(self): - device = torch_xla.device() + device = torch.device('xla') x = torch.tensor([3], device=device) y = torch.tensor([3], device=device) z = x + y @@ -26,7 +26,7 @@ def test_simple(self): self.assertEqual(stablehlo.count("stablehlo.add"), 1) def test_resnet18(self): - device = torch_xla.device() + device = torch.device('xla') xla_resnet18 = torchvision.models.resnet18() xla_resnet18.eval() xla_resnet18 = xla_resnet18.to(device) @@ -66,7 +66,7 @@ class SimpleExportTest(unittest.TestCase): def export_stable_hlo(self, model, args, kwargs=None): if kwargs is None: kwargs = {} - device = torch_xla.device() + device = torch.device('xla') model.eval() model = model.to(device) args = tuple(i.to(device) for i in args if hasattr(i, 'to')) diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index 88fce368b668..5f6a853c87ba 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -19,7 +19,7 @@ compare_exported_program_and_saved_model_result, has_tf_package, load_save_model_and_inference, wrap_func_as_nn_module) -device = torch_xla.device() +device = torch.device('xla') os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1' diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py index 51a73a402703..5c336b3b6bcf 100644 --- a/test/stablehlo/test_xla_export_interpreter.py +++ b/test/stablehlo/test_xla_export_interpreter.py @@ -7,7 +7,7 @@ import torch_xla.core.xla_model as xm from torch_xla.stablehlo import exported_program_to_stablehlo -device = torch_xla.device() +device = torch.device('xla') class XLAExportInterpreterTest(unittest.TestCase): diff --git a/test/test_autocast.py b/test/test_autocast.py index ca1f26c05ec1..468e8b932061 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -282,7 +282,7 @@ def cast(val, to_type): add_kwargs = {} self.assertFalse(self.is_autocast_enabled()) - with autocast(torch_xla.device(), dtype=autocast_dtype): + with autocast(torch.device('xla'), dtype=autocast_dtype): self.assertTrue(self.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type @@ -332,7 +332,7 @@ def compare(first, second): # 
Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. output_to_compare = output if output is not None else output_method - with autocast(torch_xla.device(), enabled=False): + with autocast(torch.device('xla'), enabled=False): self.assertFalse(self.is_autocast_enabled()) if module is not None and hasattr(module, op): @@ -355,7 +355,7 @@ class TestAutocastTPU(TestAutocastBase): @classmethod def setUpClass(cls): super().setUpClass() - cls.autocast_lists = AutocastTPUTestLists(torch.device(torch_xla.device())) + cls.autocast_lists = AutocastTPUTestLists(torch.device(torch.device('xla'))) def setUp(self): super(TestAutocastTPU, self).setUp() @@ -397,7 +397,7 @@ def test_autocast_methods_expect_builtin_promote(self): op, args, torch.float32, module=None, out_type=out_type) def test_autocast_tpu_check_dtype(self): - with autocast(torch_xla.device(), dtype=torch.float16): + with autocast(torch.device('xla'), dtype=torch.float16): assert not torch.is_autocast_xla_enabled() @@ -408,7 +408,7 @@ class TestOtherOps(unittest.TestCase): xm.xla_device_hw(torch_xla.device()) != 'TPU', "the behavior of batch_norm autocast on TPU is different from others") def test_batch_norm_tpu(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16) batch_norm = torch.nn.BatchNorm2d(16) with autocast(device, dtype=torch.bfloat16): diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index e287cb1bae55..cc8bed32b44c 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -6,7 +6,7 @@ import torch_xla.distributed.spmd.xla_sharding as xs -device = torch_xla.device() +device = torch.device('xla') class TestAutocastXla(unittest.TestCase): diff --git a/test/test_compilation_cache_utils.py b/test/test_compilation_cache_utils.py index 0ac8a013d814..5ca4a0bfb396 100644 --- a/test/test_compilation_cache_utils.py +++ b/test/test_compilation_cache_utils.py @@ -31,7 +31,7 @@ def _test_spawn(fn, args): class TestGraphHash(parameterized.TestCase): def _test_num_graph_hash(self, use_dynamo, use_persistent): - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') model = M().to(device=xla_dev) input_shape = (10, 5) if use_dynamo: diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 361c09de7faf..4f099900a83f 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -46,7 +46,7 @@ def run_export_and_compare(testcase, atol=1e-3, rtol=1e-5, equal_nan=True): - device = torch_xla.device() + device = torch.device('xla') with testcase.subTest('torch_eval'): res = func(*args, **kwargs) with testcase.subTest('torch_xla_eval'): diff --git a/test/test_data_type.py b/test/test_data_type.py index da4b7d00681f..5cdbf4aa2c2f 100644 --- a/test/test_data_type.py +++ b/test/test_data_type.py @@ -55,7 +55,7 @@ def test_datatype_use_32bit_long(self): self._test_datatype(torch.uint64, 'u32', torch.add) def test_module_to_dtype(self): - device = torch_xla.device() + device = torch.device('xla') linear = torch.nn.Linear( 5, 10, dtype=torch.float32).to(device).to(torch.bfloat16) input = torch.randn(10, 5).to(device).to(torch.bfloat16) diff --git a/test/test_env_var_mapper.py b/test/test_env_var_mapper.py index e4dcef2ba8cb..95d5c99595af 100644 --- a/test/test_env_var_mapper.py +++ b/test/test_env_var_mapper.py @@ -15,7 +15,7 @@ def check_env_flag(name, default=''): class EnvVarMapperTest(unittest.TestCase): def 
test_xla_ir_debug_(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') with xp.Trace('test_xla_ir_debug'): t = torch.tensor([2.0, 3.0], dtype=torch.float, device=xla_device) diff --git a/test/test_fp8.py b/test/test_fp8.py index 2dbf534cb5c7..fc00e0a932ac 100644 --- a/test/test_fp8.py +++ b/test/test_fp8.py @@ -6,7 +6,7 @@ import unittest from absl.testing import parameterized -device = torch_xla.device() +device = torch.device('xla') dtype_parameters = [ torch.float8_e5m2, diff --git a/test/test_fsdp_auto_wrap.py b/test/test_fsdp_auto_wrap.py index 019612899697..55ef03881a88 100644 --- a/test/test_fsdp_auto_wrap.py +++ b/test/test_fsdp_auto_wrap.py @@ -35,7 +35,7 @@ def forward(self, x): "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)" ) def test(self): - dev = torch_xla.device() + dev = torch.device('xla') input = torch.zeros([16, 16], device=dev) model = self.MyModel(input_size=16, hidden_size=4) model = XlaFullyShardedDataParallel( @@ -48,7 +48,7 @@ def test(self): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_grad_checkpoint.py b/test/test_grad_checkpoint.py index e4e318ba8310..42c9e4e93686 100644 --- a/test/test_grad_checkpoint.py +++ b/test/test_grad_checkpoint.py @@ -11,7 +11,7 @@ def run(): - device = torch_xla.device() + device = torch.device('xla') model = torch.nn.ModuleList([ torch.nn.Sequential( torch.nn.Conv2d(1024, 1024, 1), diff --git a/test/test_gradient_accumulation.py b/test/test_gradient_accumulation.py index 62ecfc431132..cfe324ed0f54 100644 --- a/test/test_gradient_accumulation.py +++ b/test/test_gradient_accumulation.py @@ -23,7 +23,7 @@ def forward(self, x): class GradAccumulationTest(XlaTestCase): def setUp(self): - self.device = torch_xla.device() + self.device = torch.device('xla') torch.manual_seed(123) def test_basic(self): diff --git a/test/test_inplace_update.py b/test/test_inplace_update.py index 704888d4f6e7..9e718d29ad17 100644 --- a/test/test_inplace_update.py +++ b/test/test_inplace_update.py @@ -11,7 +11,7 @@ class InplaceUpdateTest(unittest.TestCase): def test_aten_op_after_full_update(self): - device = torch_xla.device() + device = torch.device('xla') t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t.zero_() @@ -21,7 +21,7 @@ def test_aten_op_after_full_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_aten_op_after_partial_update(self): - device = torch_xla.device() + device = torch.device('xla') t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t[0][0] = 0 @@ -31,7 +31,7 @@ def test_aten_op_after_partial_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_non_aten_op_after_full_update(self): - device = torch_xla.device() + device = torch.device('xla') t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t.zero_() @@ -41,7 +41,7 @@ def test_non_aten_op_after_full_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_non_aten_op_after_partial_update(self): - device = torch_xla.device() + device = torch.device('xla') t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t[0][0] = 0 @@ -53,7 +53,7 @@ def test_non_aten_op_after_partial_update(self): def test_xm_save(self): with temporary_env( XLA_DISABLE_FUNCTIONALIZATION="0", 
XLA_ENABLE_PARAM_ALIASING="0"): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor([1], device=xla_device) t2 = t1.detach() torch_xla.sync() diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index 3f20f9d25c97..be3789f08785 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -38,7 +38,7 @@ def config_context(value): class InputOutputAliasesTest(parameterized.TestCase): def test_non_view(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.randn(4, 2, 2).to(xla_device) t2 = torch.randn(4, 2, 2).to(xla_device) torch_xla.sync() @@ -53,7 +53,7 @@ def test_non_view(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) def test_aliasing_with_cloned(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.randn(4, 2, 2).to(xla_device) # t1_cloned share the same storage as t1 @@ -66,7 +66,7 @@ def test_aliasing_with_cloned(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) def test_aliasing_across_custom_inplace(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.randn(4, 5).to(xla_device) t1 *= t1 @@ -78,7 +78,7 @@ def test_aliasing_across_custom_inplace(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) def test_aliasing_across_sync(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.randn(4, 5).to(xla_device) t1 += 1 @@ -96,7 +96,7 @@ def test_aliasing_with_multiple_inplace_update(self): BLOCK_SIZE = 16 DTYPE = torch.bfloat16 num_blocks = 1024 - device = torch_xla.device() + device = torch.device('xla') key = torch.randn( BATCH_SIZE * SEQ_LEN, NUM_KV_HEADS, @@ -145,7 +145,7 @@ def try_grad_accum(model, device, train_x, train_label, accum_steps): torch_xla.sync() return [p.grad.to('cpu').numpy() for p in model.parameters()] - dev = torch_xla.device() + dev = torch.device('xla') train_x_sample = torch.rand((1, 28 * 28)) train_label_sample = torch.tensor([5]) c_model = MLP().to('cpu') @@ -171,7 +171,7 @@ def test_separate_graphs(self): """ Test that paramater aliasing differences should produce different graphs. """ - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.tensor([1], device=xla_device) t1 = torch.tensor([2], device=xla_device) torch_xla.sync() @@ -190,7 +190,7 @@ def test_xm_save_no_aliasing(self): """ Test that xm.save() does not perform aliasing. """ - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.tensor([1], device=xla_device) t1 = torch.tensor([2], device=xla_device) torch_xla.sync() @@ -212,7 +212,7 @@ def test_device_data_cache_no_aliasing(self): """ Test that device data in DataCache are not aliased. 
""" - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.tensor(42, device=xla_device) # drops the read-only bit on t0's device_data @@ -235,7 +235,7 @@ def test_device_data_cache_no_aliasing(self): def test_user_config_donation_with_ltc_donation(self): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) @@ -255,7 +255,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( self, enable_buffer_donor_config): with alias_with_buffer_donor_config_context(enable_buffer_donor_config): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) @@ -279,7 +279,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( def test_user_config_donation_with_ltc_donation_overlap(self): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -291,7 +291,7 @@ def test_user_config_donation_with_ltc_donation_overlap(self): def test_user_config_donation(self): with alias_with_buffer_donor_config_context(True): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -308,7 +308,7 @@ def test_user_config_donation(self): def test_user_config_donation_inplace_aliasing(self): with alias_with_buffer_donor_config_context(True): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -322,7 +322,7 @@ def test_user_config_donation_inplace_aliasing(self): def test_user_config_donation_no_op_sync(self): with alias_with_buffer_donor_config_context(True): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) torch_xla.sync() @@ -331,7 +331,7 @@ def test_user_config_donation_no_op_sync(self): self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) def test_no_op_sync_keep_buffer_donation(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') input = torch.randn(5, 5).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True)) torch_xla.sync() @@ -346,7 +346,7 @@ def test_device_data_node_tracing_aliasing(self): for a given set of unmutated input tensor during its tracing. This helps ensure that aliasings can be retained if using the binding for tracing purposes. 
""" - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.tensor(10).to(xla_device) t1 = t0 + 5 diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py index 5016462b982e..f852f239d524 100644 --- a/test/test_jax_interop.py +++ b/test/test_jax_interop.py @@ -14,7 +14,7 @@ def setUp(self): def test_call_jax(self): """Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, b): @@ -29,7 +29,7 @@ def f(a, b): def test_call_jax_input_pytree(self): """Test that call_jax works with PyTree inputs.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((2, 2), device=dev) b = torch.ones((2, 2), device=dev) * 2 @@ -55,7 +55,7 @@ def f(inputs): def test_call_jax_output_pytree(self): """Test that call_jax works with PyTree outputs.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((2, 2), device=dev) def f(a): @@ -89,7 +89,7 @@ def f(a): def test_call_jax_some_arg_unused(self): """Test when the jax function doesn't use some input arguments.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.randn((3, 3), device=dev) b = torch.randn((3, 3), device=dev) c = torch.randn((3, 3), device=dev) @@ -106,7 +106,7 @@ def f(a, b, c, d): def test_call_jax_grad(self): """Test calling a simple jax.grad transformed function.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.randn((3, 3), device=dev, requires_grad=True) b = torch.randn((3, 3), device=dev, requires_grad=True) torch_xla.sync() @@ -143,7 +143,7 @@ def f_jax(a, b): def test_call_jax_non_tensor_args(self): """Test that call_jax works with non-tensor arguments.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, num: float, string: str, dictionary: dict, none): @@ -173,7 +173,7 @@ def test_call_jax_cache_hlo(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace two different jax functions a couple of times. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, b): @@ -198,7 +198,7 @@ def test_call_jax_cache_by_shape(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different shapes. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) b = torch.ones((2, 2), device=dev) @@ -217,7 +217,7 @@ def test_call_jax_cache_by_tree_spec(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different tree specs. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) b = torch.ones((3, 2), device=dev) @@ -237,7 +237,7 @@ def test_call_jax_cache_by_static_args(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different static args. 
- dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, num: float): @@ -255,7 +255,7 @@ def test_call_jax_with_different_jax_config(self): import jax starting_cache_misses = xb._jax_to_xla_computation_cache_elements() - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, b): diff --git a/test/test_metrics.py b/test/test_metrics.py index 69b9ab20a656..6cbbc9fc3340 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -24,7 +24,7 @@ def check_metrics_file(): class MetricsTest(unittest.TestCase): def test_clear_counters(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(100, device=xla_device) t1 += 2 self.assertIn("xla::add", met.metrics_report()) @@ -39,7 +39,7 @@ def test_clear_counters(self): assert (len(met.counter_names()) > 0) def test_clear_metrics(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(156, device=xla_device) self.assertIn("TensorToData", met.metrics_report()) assert (len(met.metric_names()) > 0) @@ -52,7 +52,7 @@ def test_clear_metrics(self): assert (len(met.metric_names()) > 0) def test_tracing_time_metrics(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.tensor(156, device=xla_device) t2 = t1 + 100 @@ -61,7 +61,7 @@ def test_tracing_time_metrics(self): def test_eager_metrics(self): with torch_xla.experimental.eager_mode_context(True): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.tensor(156, device=xla_device) t2 = t1 + 100 @@ -78,7 +78,7 @@ def test_eager_metrics(self): self.assertNotIn('ExecuteTime', met.metric_names()) def test_short_metrics_report_default_list(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(1456, device=xla_device) t2 = t1 * 2 torch_xla.sync() @@ -100,7 +100,7 @@ def test_short_metrics_report_default_list(self): assert check_metrics_file() def test_short_metrics_report_custom_list(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(100, device=xla_device) t2 = t1 * 2 t1 += 2 @@ -120,7 +120,7 @@ def test_short_metrics_report_custom_list(self): self.assertIn('InputOutputAliasCount', short_report) def test_short_metrics_fallback_counter(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(100, device=xla_device) t2 = t1 * 2 # this will trigger a aten::_local_scalar_dense which is the same as fallback counter @@ -135,7 +135,7 @@ def test_short_metrics_fallback_counter(self): def test_metrics_report(self): # TODO(jwtan): Add test to cover TrimIrGraph, SyncTensorsToData, TransferToDeviceAsync, IrValueTensorToXlaData - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(2077, device=xla_device) t2 = t1 * 2 torch_xla.sync() @@ -207,7 +207,7 @@ def test_metrics_report(self): @unittest.skipIf(xr.device_type() != "CPU", f"This test only works on CPU.") def test_execute_time_metric(self): # Initialize the client before starting the timer. 
- torch_xla.device() + torch.device('xla') begin = time.perf_counter_ns() value = torch.randn( @@ -226,7 +226,7 @@ def test_execute_time_metric(self): def test_pybind_increment_counter(self): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(2077, device=xla_device) self.assertEqual(met.counter_value('CreateXlaTensor'), 1) torch_xla._XLAC._xla_increment_counter('CreateXlaTensor', 3) diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 93d64f47ef3e..c364b69875c2 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -11,7 +11,7 @@ def all_gather(tensor, dim): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() input_list_size = 5 if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): diff --git a/test/test_mp_all_to_all.py b/test/test_mp_all_to_all.py index 9761507dea13..5a041463c7cd 100644 --- a/test/test_mp_all_to_all.py +++ b/test/test_mp_all_to_all.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'NEURON'): slots_per_device = 4 size = slots_per_device * xr.world_size() diff --git a/test/test_mp_collective_matmul.py b/test/test_mp_collective_matmul.py index 29f115c986cd..4e18fee0de2c 100644 --- a/test/test_mp_collective_matmul.py +++ b/test/test_mp_collective_matmul.py @@ -8,7 +8,7 @@ def _mp_fn(index): os.environ["ENABLE_COLLECTIVE_MATMUL_IN_MP"] = "1" - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() groups = [[i for i in range(world_size)]] scale = 1 / world_size diff --git a/test/test_mp_collective_permute.py b/test/test_mp_collective_permute.py index 81a1eb771bcd..31f9cc94ae3b 100644 --- a/test/test_mp_collective_permute.py +++ b/test/test_mp_collective_permute.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ['TPU', 'NEURON']: world_size = xr.world_size() ordinal = xr.global_ordinal() diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py index 7d6c7982cb2f..c6630e1a0a04 100644 --- a/test/test_mp_distributed_mm.py +++ b/test/test_mp_distributed_mm.py @@ -7,7 +7,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA'): world_size = xr.world_size() diff --git a/test/test_mp_early_exit.py b/test/test_mp_early_exit.py index 89e46722e232..275fb353c8db 100644 --- a/test/test_mp_early_exit.py +++ b/test/test_mp_early_exit.py @@ -12,7 +12,7 @@ def _mp_fn(): dist.init_process_group('xla', init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ['TPU', 'CUDA']: train_loader = xu.SampleGenerator( data=torch.zeros(1, 12), sample_count=1024) diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 12fc7fdfe1c8..375b8cc85b17 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() scale = 1 / world_size scatter_dim = 1 diff --git a/test/test_mp_replication.py b/test/test_mp_replication.py index 61a302a65784..c21a4b83629e 100644 --- a/test/test_mp_replication.py +++ b/test/test_mp_replication.py @@ -10,7 +10,7 @@ def all_reduce(tensor): def _mp_fn(index): - device = torch_xla.device() + device = 
torch.device('xla') world_size = xr.world_size() if world_size > 1: ones = torch.ones((2, 3)) diff --git a/test/test_mp_save.py b/test/test_mp_save.py index ae9f46df120a..4ab45e9d81a7 100644 --- a/test/test_mp_save.py +++ b/test/test_mp_save.py @@ -35,7 +35,7 @@ def _get_data_str(data): def _mp_fn(index, temp_file): - device = torch_xla.device() + device = torch.device('xla') dd = _create_state_dict(device) xm.save(dd, temp_file) # User needs to manually rendezvous since only master process diff --git a/test/test_mp_sync_batch_norm.py b/test/test_mp_sync_batch_norm.py index fa4f18ad00d2..0ac2f720099d 100644 --- a/test/test_mp_sync_batch_norm.py +++ b/test/test_mp_sync_batch_norm.py @@ -47,7 +47,7 @@ def _sync_bn1d_no_channel(rank): t_global = torch.rand((xr.world_size() * bsz, length)) # XLA SyncBatchNorm - device = torch_xla.device() + device = torch.device('xla') t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(length).to(device) result = run_step(sbn_xla, t_xla) @@ -72,7 +72,7 @@ def _sync_bn1d_multi_channel(rank): t_global = torch.rand((xr.world_size() * bsz, features, length)) # XLA SyncBatchNorm - device = torch_xla.device() + device = torch.device('xla') t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) @@ -97,7 +97,7 @@ def _sync_bn2d(rank): t_global = torch.rand((xr.world_size() * bsz, features, h, w)) # XLA SyncBatchNorm - device = torch_xla.device() + device = torch.device('xla') t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) @@ -122,7 +122,7 @@ def _sync_bn3d(rank): t_global = torch.rand((xr.world_size() * bsz, features, d, h, w)) # XLA SyncBatchNorm - device = torch_xla.device() + device = torch.device('xla') t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) diff --git a/test/test_operations.py b/test/test_operations.py index 3f6774e87413..6b0721684264 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -179,7 +179,7 @@ def onlyIfPJRTDeviceIsCUDA(fn): class TestToXlaTensorArena(test_utils.XlaTestCase): def test(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') kdata = [_gen_tensor(2, 3), _gen_tensor(3, 4)] kdata.append([_gen_tensor(2, 5), _gen_tensor(3, 6)]) @@ -307,7 +307,7 @@ def loop_fn(model, loader, device, context): class TestLongGraphChain(test_utils.XlaTestCase): def test(self): - device = torch_xla.device() + device = torch.device('xla') orig_x = torch.Tensor([[1, 2], [3, 4]]) orig_y = torch.Tensor([[0.1, 0.2], [0.3, 0.4]]) x = orig_x @@ -440,9 +440,9 @@ def test_nonzero_cast(self): class TestOptimizationBarrier(test_utils.XlaTestCase): def test_optimization_barrier_correctness(self): - device = torch_xla.device() + device = torch.device('xla') # only test optimization_barrier on TPU - if xm.xla_device_hw(device) != 'TPU': + if xr.device_type() != 'TPU': return x = torch.randn(5, 5, device=device) y = torch.randn(5, 5, device=device) @@ -459,7 +459,7 @@ def op_fn(a): return xb.Op.tuple((a, a.cast(xb.Type.BF16))) op = xor.register('test_mixed_dtype_tuple', op_fn) - xla_device = torch_xla.device() + xla_device = torch.device('xla') a_tensor = torch.randn([2, 3]).to(xla_device) a_result, a_cast = op(a_tensor) self.assertEqual(a_result.dtype, torch.float) @@ -530,7 +530,7 @@ def 
test_amp_foreach_non_finite_check_and_unscale_(self): found_inf_output0 = torch.tensor(0, dtype=torch.float32) found_inf_output1 = torch.tensor(1, dtype=torch.float32) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_grads0 = grads0.to(xla_device) xla_inv_scale = inv_scale.to(xla_device) xla_found_inf = found_inf.to(xla_device) @@ -627,7 +627,7 @@ def test_no_storage(self): def test_slice_copy(self): a = torch.rand(3, 3, 3) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_a = a.to(xla_device) shape = (4, 4, 4) b = a.new(*shape).zero_() @@ -638,7 +638,7 @@ def test_slice_copy(self): def test_slice_assign(self): a = torch.rand(3, 3, 3) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_a = a.to(xla_device) shape = (4, 4, 4) b = a.new(*shape).zero_() @@ -649,7 +649,7 @@ def test_slice_assign(self): def test_slice_stepped_assign(self): a = torch.ones((10, 4)) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_a = a.to(xla_device) a[:, 0::2] = 2 xla_a[:, 0::2] = 2 @@ -657,14 +657,14 @@ def test_slice_stepped_assign(self): def test_slice_stepped_other_assign(self): a = torch.ones((10, 4)) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_a = a.to(xla_device) a[:, 1::4] = 2 xla_a[:, 1::4] = 2 self.assertEqual(a.data, xla_a.data.cpu()) def test_ailing_slice(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.ones((1000, 324)).to(xla_device) xla_a = a.to(xla_device) w = a[:, 2::4] @@ -674,7 +674,7 @@ def test_ailing_slice(self): self.assertEqual(w.data, xla_w.data.cpu()) def test_slice_rnd_stepped_assign(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') size = 10 for s in range(0, size - 1): for e in range(1, size - s): @@ -691,7 +691,7 @@ def test_arange_nan(self): a = torch.arange(float('nan'), 5, device='xla') def test_empty_advanced_indexing(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') base = torch.randn(2, 3, 4, 5) xla_base = base.to(device=xla_device) result = base[:, torch.empty(0, 6, dtype=torch.int64)] @@ -702,7 +702,7 @@ def test_empty_advanced_indexing(self): "grad_input produces wrong results after functionalization. 
pytorch/pytorch#91199" ) def test_empty_strided(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') m = nn.Conv1d(4, 6, kernel_size=3, groups=2) a = torch.rand(2, 4, 6, requires_grad=True) xla_m = copy.deepcopy(m).to(xla_device) @@ -736,7 +736,7 @@ def test_clamp(self): self.assertEqual(b.data, xla_b.data.cpu()) def test_rrelu_module(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(1, 2, 2, requires_grad=True) xla_a = a.to(xla_device).detach() xla_a.requires_grad = True @@ -753,7 +753,7 @@ def test_rrelu_module(self): self.assertEqual(a.grad, xla_a.grad.cpu()) def test_max_broadcast(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(3, 1, 2) b = torch.rand(4, 2) c = torch.max(a, b) @@ -763,7 +763,7 @@ def test_max_broadcast(self): self.assertEqual(c.data, xla_c.data.cpu()) def test_sgn(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t = torch.randn(2, 3, dtype=torch.cfloat) # Generate inf+infj t[0][0].real.div_(0) @@ -797,7 +797,7 @@ def test_sgn(self): @skipIfFunctionalizationDisabled("view_as_real unsupported") def test_view_as_real_c64(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(4, dtype=torch.cfloat, device=xla_device) real = torch.view_as_real(x) self.assertEqual(real.dtype, torch.float32) @@ -809,7 +809,7 @@ def test_view_as_real_c64(self): @skipIfFunctionalizationDisabled("view_as_real unsupported") def test_view_as_real_c128(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(4, dtype=torch.cdouble, device=xla_device) real = torch.view_as_real(x) self.assertEqual(real.dtype, torch.float64) @@ -821,7 +821,7 @@ def test_view_as_real_c128(self): @skipIfFunctionalizationDisabled("view_as_real unsupported") def test_view_as_complex_f32(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(4, 2, device=xla_device) complex = torch.view_as_complex(x) self.assertEqual(complex.dtype, torch.complex64) @@ -834,7 +834,7 @@ def test_view_as_complex_f32(self): @skipIfFunctionalizationDisabled("view_as_real unsupported") def test_view_as_complex_f64(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(4, 2, dtype=torch.float64, device=xla_device) complex = torch.view_as_complex(x) self.assertEqual(complex.dtype, torch.complex128) @@ -847,7 +847,7 @@ def test_view_as_complex_f64(self): torch_xla._XLAC._get_xla_tensors_text([complex]).split('\n')[-3]) def test_index_put(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32) b = torch.rand(4) > 0.1 a[b] = 10 @@ -912,7 +912,7 @@ def test_fn(device): return loss, linear.weight.grad cpu_loss, cpu_weight_grad = test_fn('cpu') - xla_loss, xla_weight_grad = test_fn(torch_xla.device()) + xla_loss, xla_weight_grad = test_fn(torch.device('xla')) self.assertEqual(cpu_loss, xla_loss) self.assertEqual(cpu_weight_grad, xla_weight_grad) @@ -985,7 +985,7 @@ def func(root, b): def test_inplace_view_backprop_view(self): # modify view and backprop through view - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.tensor([2., 5.], device=xla_device, requires_grad=False) b = torch.tensor([3.], device=xla_device, requires_grad=True) res = a.narrow(0, 1, 1).mul_(b) @@ -1110,7 +1110,7 @@ def test_replace_xla_tensor(self): self.assertTrue(torch.allclose(t2.cpu(), 
torch.zeros(10))) def test_pred_type(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(4) b = torch.rand(4) xla_a = a.to(xla_device) @@ -1132,7 +1132,7 @@ def test_pred_type(self): self.runAtenTest(c, lambda x: x ^ x.byte()) def test_bitwise_and_not(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.randint(255, (4,), dtype=torch.long) xla_a = a.to(xla_device) @@ -1142,27 +1142,27 @@ def test_fn(a): self.runAtenTest(a, test_fn) def test_s_copy_dtype(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(10).to(xla_device).to(dtype=torch.uint8) b = torch.tensor([0, 1, 2, 3]).to(xla_device) self.assertEqual(a[b].dtype, torch.uint8) def test_slice_zero_sized_dim(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') v = torch.randn(2, 3, 4, 5).to(xla_device) y = v[:, :, :, 1] z = y[:, 1:1, :] self.assertEqual(z.size()[1], 0) def test_byte_dtype(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.ByteTensor([0, 1]).to(xla_device) y = torch.ByteTensor([0, 1]).to(xla_device) z = x + y self.assertEqual(z.dtype, torch.uint8) def test_frac_negative(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.tensor(-3.2) b = a.frac() xla_a = a.to(xla_device) @@ -1170,7 +1170,7 @@ def test_frac_negative(self): self.assertEqual(b, xla_b) def test_flip(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) self.assertEqual( torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0)) @@ -1193,7 +1193,7 @@ def test_flip(self): torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) def test_flip_check_throws(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) # not allow flip on the same dim more than once self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) @@ -1205,7 +1205,7 @@ def test_flip_check_throws(self): self.assertRaises(RuntimeError, lambda: data.flip(3)) def test_flip_expand(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2) transposed_data = torch.arange( @@ -1217,7 +1217,7 @@ def test_flip_expand(self): transposed_data.flip(0, 1, 2)) def test_flip_shape(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.randn(2, 3, 4, device=device) size = [2, 3, 4] test_dims = [] @@ -1227,7 +1227,7 @@ def test_flip_shape(self): self.assertEqual(size, list(data.flip(ds).size())) def test_flip_rectangular(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3).to(device) flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]).to(device) flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]).to(device) @@ -1236,13 +1236,13 @@ def test_flip_rectangular(self): self.assertEqual(flip1_result, data.flip(1)) def test_flip_empty_tensor(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([]) self.assertEqual(data, data.flip(0)) def test_norm_p0(self): # p = 0 is equivalent to nonzero - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.randn(3, 2) xla_a = a.to(xla_device) norm = a.norm(p=0) @@ -1288,7 +1288,7 @@ 
def test_fn(input, src): self.runAtenTest([torch.zeros(3, 3), torch.ones(3)], test_fn) def test_scatter_add_bool(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]]) b = torch.zeros(3, 5, dtype=torch.bool) @@ -1333,7 +1333,7 @@ def test_reduction_0dim(self): self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.mean(x)) self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.prod(x)) # min & max throws - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(2, 0, 4) xla_a = a.to(xla_device) self.assertRaises(IndexError, lambda: torch.max(a, dim=1)) @@ -1469,11 +1469,11 @@ def check(device): d = a xm.check_view_sharing([a, d]) - check(torch_xla.device()) + check(torch.device('xla')) check(torch.device('cpu')) def test_save(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(5, device=xla_device) with tempfile.NamedTemporaryFile() as tf: torch.save(x, tf) @@ -1481,7 +1481,7 @@ def test_save(self): self.assertEqual(x, x_loaded) def test_save_bf16(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(5, dtype=torch.bfloat16, device=xla_device) with tempfile.NamedTemporaryFile() as tf: torch.save(x, tf) @@ -1489,7 +1489,7 @@ def test_save_bf16(self): self.assertEqual(x, x_loaded) def test_save_tuple(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(5, device=xla_device) number = 3 with tempfile.NamedTemporaryFile() as tf: @@ -1499,7 +1499,7 @@ def test_save_tuple(self): self.assertEqual(number, number_loaded) def test_save_api(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') model = XlaMNIST().to(xla_device) with tempfile.NamedTemporaryFile() as tf: xm.save(model.state_dict(), tf) @@ -1512,7 +1512,7 @@ def test_save_api(self): def test_serialization_api(self): with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, 'data.pt') - xla_device = torch_xla.device() + xla_device = torch.device('xla') model = XlaMNIST().to(xla_device) xser.save(model.state_dict(), path) state_dict = xser.load(path) @@ -1522,7 +1522,7 @@ def test_serialization_api(self): self.assertEqual(model.state_dict(), loaded_model.state_dict()) def test_deepcopy(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.rand(5, device=xla_device) x0 = x[0] y = copy.deepcopy(x) @@ -1532,7 +1532,7 @@ def test_deepcopy(self): self.assertEqual(x[0], x0) def test_print(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla:0') x = torch.tensor([5], device=xla_device) expected_str = 'tensor([5], device=\'' + str(xla_device) + '\')' self.assertEqual(str(x), expected_str) @@ -1727,14 +1727,14 @@ def test_fn(t): self.runAtenTest([torch.tensor(20.0)], test_fn) def test_view_and_copy_(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device='cpu') y = torch.tensor([0, 0, 0, 0, 0, 0], device=xla_device) y[::2].copy_(x[::2]) self.assertEqual(y, [1, 0, 3, 0, 5, 0]) def test_view_and_multi_sync(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.zeros(100, device=xla_device) t1[10] = 113 torch_xla.sync() @@ -1744,7 +1744,7 @@ def test_view_and_multi_sync(self): torch_xla._XLAC._get_xla_tensors_text([t1])) def test_binaryop_order(self): - xla_device = torch_xla.device() + xla_device = 
torch.device('xla') x = torch.rand(5, device=xla_device) y = torch.rand(5) self.assertEqual(x + y, y + x) @@ -1759,7 +1759,7 @@ def test_pow_constant(self): assert 'xla::device_data' not in const_hlo def test_emb_bf16(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') index = torch.ones(1, dtype=torch.long, device=xla_device) emb = torch.nn.Embedding(1024, 128, device=xla_device) emb = emb.to(torch.bfloat16) @@ -1779,7 +1779,7 @@ def test_on_device(device): return m(index) out = test_on_device("cpu") - out_x = test_on_device(torch_xla.device()) + out_x = test_on_device(torch.device('xla')) self.assertEqual(out, out_x.cpu()) def test_transpose_1d(self): @@ -1798,7 +1798,7 @@ def test_fn(t1): def test_sigmoid_bounds(self): torch.manual_seed(0) - xla_device = torch_xla.device() + xla_device = torch.device('xla') for _ in range(100): x = torch.rand(1000).to(xla_device) lower_bound = torch.sigmoid(x * (-100.0)) @@ -1807,7 +1807,7 @@ def test_sigmoid_bounds(self): assert torch.all(upper_bound <= 1.0) def test_manual_seed(self): - device = torch_xla.device() + device = torch.device('xla') torch_xla.manual_seed(12345) t1 = torch.randn(5, 5, device=device) torch_xla.manual_seed(12345) @@ -1815,7 +1815,7 @@ def test_manual_seed(self): self.assertTrue(torch.allclose(t1.cpu(), t2.cpu())) def test_cached_addcdiv(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.randn(1, 3).to(xla_device) @@ -1833,7 +1833,7 @@ def test_cached_addcdiv(self): @skipOnEagerDebug def test_print_execution(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') torch_xla.sync() xm.wait_device_ops() met.clear_all() @@ -1887,7 +1887,7 @@ def test_fn(input): return dropped[1].cpu(), input.grad.cpu() met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') input_cpu = torch.randn(7, 7, requires_grad=True) input_xla = torch.randn(7, 7, device=xla_device, requires_grad=True) mask_cpu, grad_cpu = test_fn(input_cpu) @@ -2045,7 +2045,7 @@ def foo(x): x = torch.arange(10).to(dtype) r = foo(x) - device = torch_xla.device() + device = torch.device('xla') Xx = x.to(device) Xr = foo(Xx) @@ -2074,7 +2074,7 @@ def func(input_volume): return F.interpolate( input_volume, size=output_size, mode='trilinear', align_corners=False) - device = torch_xla.device() + device = torch.device('xla') input_volume = torch.randn(1, 3, 16, 32, 32).to(device) met.clear_all() self.runAtenTest((input_volume), func) @@ -2105,7 +2105,7 @@ def foo(t): t.retain_grad() t.grad = torch.rand(10, 10, dtype=torch.bfloat16) xt = t.to('xla') - xt.grad = t.grad.to(torch_xla.device(), dtype=torch.bfloat16) + xt.grad = t.grad.to(torch.device('xla'), dtype=torch.bfloat16) foo(t) foo(xt) @@ -2393,7 +2393,7 @@ def run(device): return runf(*args_) actual = run("cpu") - expected = run(torch_xla.device()) + expected = run(torch.device('xla')) self.assertFalse( met.executed_fallback_ops(), msg="expected no fallback operations.") @@ -2452,7 +2452,7 @@ class TestModelComparator(test_utils.XlaTestCase): def test(self): SEED = 42 - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = _gen_tensor(8, 1, 28, 28) xla_x = x.to(xla_device) @@ -2477,7 +2477,7 @@ def test(self): class TestWaitDeviceOps(test_utils.XlaTestCase): def test_wait_device_ops(self): - torch_xla.device() + torch.device('xla') value = torch.randn(10000, 10000, device='xla') val_list = [] val_mean_list = [] @@ -2496,7 +2496,7 @@ class TestDebuggingUtil(test_utils.XlaTestCase): 
@skipOnEagerDebug def test_get_xla_tensor_debug_info(self): - device = torch_xla.device() + device = torch.device('xla') # test non xla tensor cpu_t1 = torch.randn(5) cpu_t1_info = torch_xla._XLAC._get_xla_tensor_debug_info(cpu_t1) @@ -2531,7 +2531,7 @@ def runOpBuilderTest(self, kwargs=dict()): op = xor.register(name, opfn) if device is None: - device = torch_xla.device() + device = torch.device('xla') if aten_fn is None: aten_fn = opfn tensors = xu.as_list(tensors) @@ -2653,7 +2653,7 @@ class MpDecoratorTest(test_utils.XlaTestCase): @xtu.mp_test def test_mp_decorator(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') self.assertTrue(xla_device.type == 'xla') @@ -2692,7 +2692,7 @@ class TestLoweringContext(test_utils.XlaTestCase): def test_api(self): met.clear_all() - device = torch_xla.device() + device = torch.device('xla') a = torch.tensor([1.0, 2.0, 3.0], device=device) b = torch.tensor([4.0, 5.0, 6.0], device=device) @@ -2720,7 +2720,7 @@ def test_get_parameters_scalar(self): that appropriately. """ - device = torch_xla.device() + device = torch.device('xla') tensors = [] for i in range(10): # Add three copies of the same value. @@ -2753,13 +2753,13 @@ def test_git_revisons(self): self.assertTrue('torch' in revs) def test_send_to_device_grad(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t = _gen_tensor(2, 2, requires_grad=True) dt = xm.send_cpu_data_to_device([t], xla_device) self.assertTrue(dt[0].requires_grad) def test_send_to_device_single(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla:0') t = _gen_tensor(2, 2) dt = xm.send_cpu_data_to_device(t, xla_device) self.assertEqual(dt[0].device, xla_device) @@ -2859,7 +2859,7 @@ def from_tensors(self, tensors): wpack = PackWrapper(pack) - xla_device = torch_xla.device() + xla_device = torch.device('xla:0') xdata = xm.send_cpu_data_to_device(wpack, xla_device) self.assertTrue(isinstance(xdata, nn.utils.rnn.PackedSequence)) self.assertEqual(xdata.batch_sizes.device, torch.device('cpu')) @@ -2869,7 +2869,7 @@ def from_tensors(self, tensors): "https://github.com/pytorch/xla/pull/7864#issuecomment-2294034008") def test_as_strided_input_larger(self): size = (5, 5) - device = torch_xla.device() + device = torch.device('xla') a = torch.ones(size, device=device) small_a = a[:, ::2] @@ -2899,7 +2899,7 @@ def test_aten_move_scalar_cuda_to_xla(self): self._test_move_tensor_cuda_to_xla(torch.tensor(42)) def test_unsafe_buffer_pointer(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_tensor_0 = torch.tensor(42).to(xla_device) # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr torch_xla.sync() @@ -2944,7 +2944,7 @@ def _test_dlpack_capsule_conversion_helper(self, xla_tensor): @onlyIfPJRTDeviceIsCUDA @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16)) def test_dlpack_roundtrip_tensor(self, dtype): - xla_device = torch_xla.device() + xla_device = torch.device('xla') # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr # xla_tensor_2 uses XLANativeFunctions::_to_copy xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device) @@ -2961,7 +2961,7 @@ def test_dlpack_roundtrip_tensor(self, dtype): *all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64)) def test_dlpack_roundtrip_scalar(self, dtype): - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_tensor_0 = torch.tensor(42, 
dtype=dtype).to(xla_device) # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr torch_xla.sync() @@ -3118,7 +3118,7 @@ def forward(self, inp): class TestActivationCheckpoint(test_utils.XlaTestCase): def test_dropout(self): - device = torch_xla.device() + device = torch.device('xla') model = SimpleModelWithDropout().to(device) model = checkpoint_module(model) _input = torch.randn(128, 128, requires_grad=True) @@ -3132,7 +3132,7 @@ def test_dropout(self): f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}") def test_opt_barrier(self): - device = torch_xla.device() + device = torch.device('xla') model = SimpleModelWithDropout().to(device) model = checkpoint_module(model) _input = torch.randn(128, 128, requires_grad=True) @@ -3167,7 +3167,7 @@ def _reference_nms(self, boxes, scores, iou_threshold): def _nms(self, boxes, scores, iou_threshold): import torchvision - device = torch_xla.device() + device = torch.device('xla') return torchvision.ops.nms( boxes.to(device), scores.to(device), iou_threshold).cpu() @@ -3240,7 +3240,7 @@ class TestHelperFunction(test_utils.XlaTestCase): def test_repeat_truncated(self): from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size met.clear_all() - device = torch_xla.device() + device = torch.device('xla') total_repeat_length = 20 input = torch.randn(10).to(device) repeats = torch.tensor([0, 1, 2, 0, 4, 0, 6, 7, 8, 9]).to(device) @@ -3253,7 +3253,7 @@ def test_repeat_truncated(self): def test_repeat_extended(self): from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size met.clear_all() - device = torch_xla.device() + device = torch.device('xla') total_repeat_length = 100 input = torch.randn(10).to(device) repeats = torch.tensor([0, 5, 2, 0, 4, 9, 6, 7, 8, 0]).to(device) @@ -3271,7 +3271,7 @@ def test_repeat_extended(self): def test_repeat_special(self): from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size met.clear_all() - device = torch_xla.device() + device = torch.device('xla') total_repeat_length = 135 num_groups = 8 input = torch.arange(num_groups, dtype=torch.int32).to(device) diff --git a/test/test_placeholder.py b/test/test_placeholder.py index d5506bfacd55..5b6c2096a39e 100644 --- a/test/test_placeholder.py +++ b/test/test_placeholder.py @@ -19,7 +19,7 @@ def test_create_placeholder(self): ): p = create_placeholder_tensor(shape, dtype) assert isinstance(p, torch.Tensor) - assert p.device == torch_xla.device() + assert p.device == torch.device('xla') self.assertEqual(p.dtype, dtype) self.assertEqual(p.shape, shape) self.assertTrue(torch_xla._XLAC._is_placecholder(p)) @@ -56,7 +56,7 @@ def test_placeholder_handle_unique(self): self.assertNotEqual(h1, h2) def test_cannot_get_handle_from_deleted_pjrt_buffer(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) diff --git a/test/test_profile_mp_mnist.py b/test/test_profile_mp_mnist.py index e23c2f59c223..2dbd67655918 100644 --- a/test/test_profile_mp_mnist.py +++ b/test/test_profile_mp_mnist.py @@ -144,7 +144,7 @@ def train_mnist(flags, # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') model = MNIST().to(device) writer = None if xm.is_master_ordinal(): diff --git a/test/test_python_ops.py b/test/test_python_ops.py index 9dc145947f62..557bf5c4c278 100644 --- 
a/test/test_python_ops.py +++ b/test/test_python_ops.py @@ -29,8 +29,8 @@ def test_put(self, dtype): raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) - device = torch_xla.device() - real_device_type = xm.xla_device_hw(str(torch_xla.device())) + device = torch.device('xla') + real_device_type = xm.xla_device_hw(str(torch.device('xla'))) if real_device_type == "TPU": raise unittest.SkipTest("TestPut is too slow on TPU. Skipped") @@ -108,7 +108,7 @@ def test_index_copy(self, dtype): raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) - device = torch_xla.device() + device = torch.device('xla') # We just test for num_copy <= num_dest, as otherwise there are repeated indices # and the behavior is undefined diff --git a/test/test_syncfree_optimizers.py b/test/test_syncfree_optimizers.py index 8807271440c6..593ea06b83f1 100644 --- a/test/test_syncfree_optimizers.py +++ b/test/test_syncfree_optimizers.py @@ -53,7 +53,7 @@ def _test_optimizer(self, syncfree_optim_cls, ref_optim_cls, optim_kwargs={'lr': 1e-2}): - device = torch_xla.device() + device = torch.device('xla') loss_fn = nn.NLLLoss() # syncfree model torch.manual_seed(0) diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py index 98730dbf7009..898c249625e2 100644 --- a/test/test_torch_distributed_fsdp_frozen_weight.py +++ b/test/test_torch_distributed_fsdp_frozen_weight.py @@ -7,7 +7,7 @@ def _mp_fn(index): - dev = torch_xla.device() + dev = torch.device('xla') if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'): print( 'Default device {} is not a TPU or CUDA device'.format(dev), diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index a3069a6637ec..54626415255f 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -63,7 +63,7 @@ def test_xla_backend_exists(self): self.assertIsNotNone(pg_xla_creator) def test_allreduce(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\(' dist.all_reduce(tensor) @@ -72,7 +72,7 @@ def test_allreduce(self): @patch_world(rank=3, size=6) def test_allreduce_with_mesh(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() pg_options = {'xla_pg_options': {'spmd': True}} @@ -89,7 +89,7 @@ def test_allreduce_with_mesh(self): @patch_world(rank=3, size=8) def test_allgather(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)] all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\(' @@ -99,7 +99,7 @@ def test_allgather(self): @patch_world(rank=3, size=8) def test_all_scalar_allgather(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.zeros((), device=device) + 1 + 2 * dist.get_rank() output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)] all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\(' @@ -109,7 +109,7 @@ def test_all_scalar_allgather(self): @patch_world(rank=3, size=8) def test_allgather_coalesced(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() tensor2 = torch.arange(5, 
device=device) + 1 + 2 * dist.get_rank() pg_xla = get_process_group_xla(rank=3, size=8) @@ -127,7 +127,7 @@ def test_allgather_coalesced(self): hlo_matches(hlo, all_gather_pattern) def test_broadcast(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\(' dist.broadcast(tensor, 0) @@ -136,7 +136,7 @@ def test_broadcast(self): # Needed for ZeRO stage 1 def test_reduce_scatter(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() input_list = [tensor] output = torch.zeros_like(tensor) @@ -148,7 +148,7 @@ def test_reduce_scatter(self): @skipIf(xr.device_type() == 'CPU', "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.") def test_reduce_scatter_coalesced(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() input_tensors_list = [[tensor, tensor], [tensor2, tensor2]] @@ -168,7 +168,7 @@ def test_reduce_scatter_coalesced(self): @patch_world(0, 6) def test_send(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() input_list = [tensor] @@ -185,11 +185,11 @@ def test_send(self): hlo_matches(hlo, senddone_pattern) # Don't try to run Send on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) + torch_xla._XLAC._clear_pending_irs(str(torch.device('xla'))) @patch_world(0, 6) def test_recv(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() with mock.patch.object( @@ -205,7 +205,7 @@ def test_recv(self): hlo_matches(hlo, recvdone_pattern) # Don't try to run Recv on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) + torch_xla._XLAC._clear_pending_irs(str(torch.device('xla'))) @patch_world(rank=0, size=12) def test_new_group_no_ranks(self): @@ -365,7 +365,7 @@ def test_barrier(self): 'monitored_barrier', ) def test_unimplemented_op(self, op): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() pg_xla = dist.group.WORLD self.assertIsInstance(pg_xla, diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index efb34a2cc3af..0a031c1d0cba 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -250,7 +250,7 @@ def train_imagenet(): torch.manual_seed(42) - device = torch_xla.device() + device = torch.device('xla') model = get_model_property('model_fn')().to(device) # Initialization is nondeterministic with multiple threads in PjRt. 
diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index 290857281fd7..c5bf26b9e4cf 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -194,7 +194,7 @@ def train_imagenet(): torch.manual_seed(42) - device = torch_xla.device() + device = torch.device('xla') device_hw = xm.xla_device_hw(device) model = get_model_property('model_fn')().to(device) writer = None @@ -229,7 +229,7 @@ def train_loop_fn(loader, epoch): for step, (data, target) in enumerate(loader): optimizer.zero_grad() if FLAGS.amp: - with autocast(torch_xla.device()): + with autocast(torch.device('xla')): output = model(data) loss = loss_fn(output, target) if scaler: diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index 1d939d8385b3..3423b3e4df59 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -241,7 +241,7 @@ def train_imagenet(): torch.manual_seed(42) - device = torch_xla.device() + device = torch.device('xla') model = get_model_property('model_fn')() # Automatic wrapping sub-modules with inner FSDP auto_wrap_policy = None diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 0a5e46fdcd1f..4aa328752e89 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -130,7 +130,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') model = MNIST().to(device) # Initialization is nondeterministic with multiple threads in PjRt. diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 0bd393b21f2e..d6fac172003a 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -130,7 +130,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') device_hw = xm.xla_device_hw(device) model = MNIST().to(device) diff --git a/test/test_train_mp_mnist_fsdp_with_ckpt.py b/test/test_train_mp_mnist_fsdp_with_ckpt.py index 833612a2be49..c6aa20bc1d68 100644 --- a/test/test_train_mp_mnist_fsdp_with_ckpt.py +++ b/test/test_train_mp_mnist_fsdp_with_ckpt.py @@ -164,7 +164,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') model = MNIST() # Automatic wrapping sub-modules with inner FSDP auto_wrap_policy = None diff --git a/test/test_train_mp_mnist_zero1.py b/test/test_train_mp_mnist_zero1.py index 523bf5fc0a19..11926c273697 100644 --- a/test/test_train_mp_mnist_zero1.py +++ b/test/test_train_mp_mnist_zero1.py @@ -114,7 +114,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') model = MNIST().to(device) writer = None diff --git a/test/test_user_computation_debug_cache.py b/test/test_user_computation_debug_cache.py index f83f856c2cfd..a6fb1cd885ae 100644 --- a/test/test_user_computation_debug_cache.py +++ b/test/test_user_computation_debug_cache.py @@ -40,7 +40,7 @@ def input_scope_0(tensor): def input_scope_1(tensor): return [torch.sin(tensor), torch.cos(tensor)] - device = torch_xla.device() + device = torch.device('xla') init_tensor = torch.tensor(10).to(device) def create_user_computation(fn): diff --git a/test/test_utils.py b/test/test_utils.py index 
6a913f932e4d..2bbf7255182c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -384,7 +384,7 @@ def compareResults(self, results, xla_results, rel_err=1e-2, abs_err=1e-5): def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5): if device is None: - device = torch_xla.device() + device = torch.device('xla') tensors = xu.as_list(tensors) xla_tensors = [ x.to(device).detach().requires_grad_(x.requires_grad) for x in tensors diff --git a/test/test_while_loop.py b/test/test_while_loop.py index 4dc0a17a96ea..d58b18eb3e45 100644 --- a/test/test_while_loop.py +++ b/test/test_while_loop.py @@ -26,7 +26,7 @@ def _fake_while_loop(cond_fn, body_fn, operands): class WhileLoopTest(unittest.TestCase): def test_while_loop_addition(self): - device = torch_xla.device() + device = torch.device('xla') def cond_fn(iteri, x): return iteri > 0 @@ -41,7 +41,7 @@ def body_fn(iteri, x): self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) def test_while_loop_addition_nested(self): - device = torch_xla.device() + device = torch.device('xla') def cond_fn(iteri, x): return iteri > 0 @@ -56,7 +56,7 @@ def body_fn(iteri, x): self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) def test_while_loop_simple_linear_inside_loop(self): - device = torch_xla.device() + device = torch.device('xla') torch.set_grad_enabled(False) class SimpleLinear(torch.nn.Module): @@ -94,7 +94,7 @@ def forward_without_while_loop_op(self, iteri, x): # ====== fori_loop ====== @unittest.skip("Fori_loop is not supported now due to unstable result.") def test_fori_loop_addition(self): - device = torch_xla.device() + device = torch.device('xla') lower = torch.tensor(0, device=device) upper = torch.tensor(50, device=device) diff --git a/test/test_zero1.py b/test/test_zero1.py index 8bb2fbc3d822..1a798abc1d9c 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -34,7 +34,7 @@ class XlaZeRO1Test(test_utils.XlaTestCase): @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU") def test_zero1(self): - device = torch_xla.device() + device = torch.device('xla') model = nn.Linear(32, 32) x = torch.ones((32, 32)) @@ -89,7 +89,7 @@ def test_zero1(self): torch_xla.sync() def test_zero1_load(self): - device = torch_xla.device() + device = torch.device('xla') model = nn.Linear(32, 32) x = torch.ones((32, 32)) @@ -153,7 +153,7 @@ def test_zero1_load(self): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/torch_distributed/test_ddp.py b/test/torch_distributed/test_ddp.py index 1d91f520d5aa..6e8c01a3f7b9 100644 --- a/test/torch_distributed/test_ddp.py +++ b/test/torch_distributed/test_ddp.py @@ -24,7 +24,7 @@ def _ddp_correctness(rank, gradient_as_bucket_view: bool = False): # We cannot run this guard before XMP, # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing. 
- device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) not in ('TPU', 'CUDA'): print( 'Default device {} is not a TPU device'.format(device), diff --git a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py index 7c30b211ad49..125121b8c798 100644 --- a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py index 2fd71d2ed84e..18bec4fecdc0 100644 --- a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() dist.init_process_group('xla', init_method='xla://') diff --git a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py index c462f7552800..82eff827fc9f 100644 --- a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_fsdp_meta.py b/test/torch_distributed/test_torch_distributed_fsdp_meta.py index 444c47890330..2f382eb86246 100644 --- a/test/torch_distributed/test_torch_distributed_fsdp_meta.py +++ b/test/torch_distributed/test_torch_distributed_fsdp_meta.py @@ -141,7 +141,7 @@ def meta_module_fn(): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') # This test fails on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) if xm.xla_device_hw(device) in ('TPU', 'NEURON'): dist.init_process_group('xla', init_method='xla://') diff --git a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py index 9089f9d799ff..affc32c4a73d 100644 --- a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py index 006d3fd33a95..90ccbfb64d0c 100644 --- a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - 
device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/utils/train_spmd_linear_model.py b/test/utils/train_spmd_linear_model.py index ac0dd9f86b22..4407bd665f65 100644 --- a/test/utils/train_spmd_linear_model.py +++ b/test/utils/train_spmd_linear_model.py @@ -69,7 +69,7 @@ def forward(self, x): def train(): - device = torch_xla.device() + device = torch.device('xla') torch.manual_seed(42) model = SimpleLinear().to(device) print('===> Preparing data..') diff --git a/test/utils/train_spmd_linear_model_grad_acc.py b/test/utils/train_spmd_linear_model_grad_acc.py index 62f3e79ae4a0..2d6ccfd71a3c 100644 --- a/test/utils/train_spmd_linear_model_grad_acc.py +++ b/test/utils/train_spmd_linear_model_grad_acc.py @@ -77,7 +77,7 @@ def forward(self, x): def train(): - device = torch_xla.device() + device = torch.device('xla') num_devices = xr.global_runtime_device_count() print(f'num_devices: {num_devices}') # Define a mesh with all devices along one axis diff --git a/torch_xla/_dynamo/dynamo_backend2.py b/torch_xla/_dynamo/dynamo_backend2.py index e3fee43f792b..1d515c9cc63f 100644 --- a/torch_xla/_dynamo/dynamo_backend2.py +++ b/torch_xla/_dynamo/dynamo_backend2.py @@ -34,7 +34,7 @@ def _dynamo_backend(model: torch.fx.GraphModule, sample_args: Any): jax.config.update("jax_enable_x64", True) env = torchax.default_env() - xla_device = torch_xla.device() + xla_device = torch.device('xla') def run_jax(*args, initial_rng_key): args_t = torchax.interop.torch_view(args) diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py index 7cae4f7392e5..1061406746f7 100644 --- a/torch_xla/_dynamo/dynamo_bridge.py +++ b/torch_xla/_dynamo/dynamo_bridge.py @@ -498,7 +498,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule, # 2. All of the pending IRs are result of our warm up cache tracing and they # should be removed to avoid extra computation executed and in place updates op # mistakenlly update the input tensors. - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) + torch_xla._XLAC._clear_pending_irs(str(torch.device('xla'))) vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash, arg_index_to_need_update_index, none_remover, @@ -567,7 +567,7 @@ def optimized_mod(*args: tuple): is_cuda_args = original_device.type == "cuda" if is_cuda_args: - args = _maybe_move_tensors_to_device(args, torch_xla.device()) + args = _maybe_move_tensors_to_device(args, torch.device('xla')) if not config.skip_input_data_check: # `torch_xla.sync()` needs to be blocking since we want to access args's @@ -768,7 +768,7 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args, # UnsupportedNodesCollector might trigger in place ops, need to clear them here. 
_clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args) - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) + torch_xla._XLAC._clear_pending_irs(str(torch.device('xla'))) class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): @@ -813,7 +813,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args): if _args_on_cuda(xla_args): xla_args = tuple( - _maybe_move_tensors_to_device(xla_args, torch_xla.device())) + _maybe_move_tensors_to_device(xla_args, torch.device('xla'))) # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its # value reference before actually computing it. diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 6b68e656d333..6d437ecab5ee 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -118,7 +118,7 @@ def master_print(*args: Any, print(*args, file=fd, flush=flush) -@deprecated("Use torch_xla.device instead") +@deprecated("Use torch.device('xla') instead") def xla_device(n: Optional[int] = None, devkind: Optional[str] = None) -> torch.device: """Returns a given instance of an XLA device. diff --git a/torch_xla/core/xla_op_registry.py b/torch_xla/core/xla_op_registry.py index 62943f4c70c5..2e82d49028d4 100644 --- a/torch_xla/core/xla_op_registry.py +++ b/torch_xla/core/xla_op_registry.py @@ -68,7 +68,7 @@ def slice_and_add(a, b, dimno=0): SLICE_AND_ADD = xor.register('slice_and_add', slice_and_add) def user_computation_test(): - device = torch_xla.device() + device = torch.device('xla') x = torch.randn(2, 2).to(device) y = torch.randn(2, 2).to(device) z = SLICE_AND_ADD(x, y, dimno=0) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 68bb7ea7a48e..feda7894c081 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1310,7 +1310,7 @@ void BuildLoweringContextSubmodule(py::module* m) { * import torch_xla * import torch_xla.core.xla_model as xm * - * device = torch_xla.device() + * device = torch.device('xla') * example = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device) * * def network(x): diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index c5605d2b3ed2..fb5e41cc92c5 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -139,7 +139,7 @@ class XlaFullyShardedDataParallel(nn.Module): module (nn.Module): module to be wrapped with FSDP. If the input module's parameters and buffers are not already on XLA device, they will be cast to - ``torch_xla.device()`` (after sharding) during FSDP initialization. + ``torch.device('xla')`` (after sharding) during FSDP initialization. reshard_after_forward (bool, Optional): if ``True``, reshard parameters after the forward pass. This saves memory but slows training. 
This is only relevant when resharding @@ -527,7 +527,7 @@ def __init__( List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params - self.xla_device = torch_xla.device() + self.xla_device = torch.device('xla') # Shard module parameters in place self._shard_parameters_(params_to_shard) # Cast the module buffers to the specified buffer_dtype @@ -1646,7 +1646,7 @@ def _print_r0(self, msg: str, restart: bool = False) -> None: if restart: self._tstart = time.time() if self.rank == 0: - memory_info = xm.get_memory_info(torch_xla.device()) + memory_info = xm.get_memory_info(torch.device('xla')) gb_free = memory_info["kb_free"] / 1024 / 1024 gb_total = memory_info["kb_total"] / 1024 / 1024 logging.info( diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py index 05a37fe9b411..f84b71d32f9d 100644 --- a/torch_xla/distributed/parallel_loader.py +++ b/torch_xla/distributed/parallel_loader.py @@ -265,7 +265,7 @@ class MpDeviceLoader(object): Example: - >>> device = torch_xla.device() + >>> device = torch.device('xla') >>> train_device_loader = MpDeviceLoader(train_loader, device) """ diff --git a/torch_xla/distributed/spmd/api.py b/torch_xla/distributed/spmd/api.py index 567fba1ad015..29c930af5d9a 100644 --- a/torch_xla/distributed/spmd/api.py +++ b/torch_xla/distributed/spmd/api.py @@ -216,7 +216,7 @@ def xla_distribute_module( if partition_fn: if getattr(partition_fn, '__name__', 'unknown') == "auto_policy": # TODO(yeounoh) allow pre-loading to xla device in the future. - assert next(module.parameters()).device != torch_xla.device(), \ + assert next(module.parameters()).device != torch.device('xla'), \ f"Currently requires module to be on cpu, before xla_distribute_module." xr.use_spmd(auto=True) module = module.to('xla') diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 49229b17cffe..239a2bce1043 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -334,7 +334,7 @@ def _get_physical_tpu_mesh(self, devices: np.ndarray) -> np.ndarray: A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On v2 and v3, global_z is instead cores_per_chip (i.e., 2). """ - assert xm.xla_device_hw(torch_xla.device()) == 'TPU' + assert xm.xla_device_hw(torch.device('xla')) == 'TPU' # coords is a 3-dims tuple representing the device in physical mesh device_coords = [self.device_attributes[d]['coords'] for d in devices] dims = tuple(d + 1 for d in max(device_coords)) @@ -826,7 +826,7 @@ def can_apply(self, t: torch.Tensor) -> bool: def apply(self, t: torch.Tensor): # TODO(yeounoh) use virtual device interface when available. - assert (t.device == torch_xla.device()) + assert (t.device == torch.device('xla')) mark_sharding(t, self.mesh, self.partition_spec) diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index e3b349a4b7fb..b14fde5bb1a8 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -56,7 +56,7 @@ class MpModelWrapper(object): WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork()) def _mp_fn(index, ...): - device = torch_xla.device() + device = torch.device('xla') model = WRAPPED_MODEL.to(device) ... 
diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ea4c8d54c1a2..f0100dec87bd 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -16,7 +16,7 @@ def fori_loop(lower, upper, body_fun, *input_value): - device = torch_xla.device() + device = torch.device('xla') if (upper < lower): print("ERROR: upper should be a larger number than lower") iteri = upper - lower diff --git a/torch_xla/experimental/gradient_accumulation.py b/torch_xla/experimental/gradient_accumulation.py index 4e3f8682e68e..0855fffbd62a 100644 --- a/torch_xla/experimental/gradient_accumulation.py +++ b/torch_xla/experimental/gradient_accumulation.py @@ -154,7 +154,7 @@ def _make_init_grad(param): def _gradient_accumulation_impl(context, body_fn, iterable_tensors, params, carried_tensors): builder = XlaBuildHelper('grad_acc') - device = torch_xla.device() + device = torch.device('xla') init_iterator = torch.tensor(0, dtype=torch.int32, device=device) init_loss = torch.tensor(0, dtype=torch.float32, device=device) diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 565b569ed726..ee8fabe5a3ac 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -143,7 +143,7 @@ def scan(fn, init, xs): >>> y = new_carry >>> return new_carry, y >>> - >>> with torch_xla.device(): + >>> with torch.device('xla'): >>> init = torch.tensor([0.0, 0.0], requires_grad=True) >>> xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], >>> requires_grad=True) @@ -650,7 +650,7 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor: t = xb.create_placeholder_tensor(v.shape, v.dtype) return t.requires_grad_(v.requires_grad) - device = torch_xla.device() + device = torch.device('xla') fake_carry = tree_map(make_fake_tensor, init) fake_x = tree_map(lambda v: make_fake_tensor(v[0]), xs) diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py index 4e55111caeec..3bbd78196fd5 100644 --- a/torch_xla/experimental/scan_layers.py +++ b/torch_xla/experimental/scan_layers.py @@ -50,7 +50,7 @@ def scan_layers(layers: Iterable[torch.nn.Module], >>> import torch >>> import torch.nn as nn >>> from torch_xla.experimental.scan_layers import scan_layers - >>> with torch_xla.device(): + >>> with torch.device('xla'): >>> layers = [nn.Linear(16, 16) for i in range(10)] >>> input = torch.randn(16) >>> output = scan_layers(layers, input) diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 2e274190db75..e5aef103a17a 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -156,7 +156,8 @@ def local_ordinal() -> int: Local ordinal is in range [0, local_device_count).""" local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0) devices_per_process = addressable_device_count() - return local_rank * devices_per_process + torch_xla.device().index + return local_rank * devices_per_process + torch.device( + torch_xla._XLAC._xla_get_default_device()).index def process_index() -> int: diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index b88a8131b2d8..1a70f7972af6 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -341,7 +341,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, assert len(kwargs) == 0, "Export to stablehlo doesnt support kwargs yet." 
- device = torch_xla.device() + device = torch.device('xla') _flat_input_args = exported_model._graph_module_flat_inputs(args, {}) _flat_input_args = pytree.tree_map_only(torch.Tensor, @@ -352,7 +352,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, torch_xla.sync() xm.wait_device_ops() metrics.clear_counters() - device = torch_xla.device() + device = torch.device('xla') # Run the fx graph tracing using lazy tensor if options.inline_all_constant: diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index 9062d6a9ef21..739b1147b28a 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -4,6 +4,7 @@ import functools import uuid from typing import Any, Callable, List, Optional, Tuple +from typing_extensions import deprecated import weakref import torch @@ -16,6 +17,7 @@ import torch_xla.utils.utils as xu +@deprecated("Use torch.device('xla') instead") def device(index: int = None) -> torch.device: """Returns a given instance of an XLA device.