diff --git a/test/ds/test_dynamic_shape_models.py b/test/ds/test_dynamic_shape_models.py
index 114c41e5c82..36bb87b3876 100644
--- a/test/ds/test_dynamic_shape_models.py
+++ b/test/ds/test_dynamic_shape_models.py
@@ -44,7 +44,7 @@ def forward(self, x):
 
 
 @unittest.skipIf(
-    xm.xla_device_hw(torch_xla.device()) != 'TPU',
+    xr.device_type() != 'TPU',
     f"The tests fail on CPU. See https://github.com/pytorch/xla/issues/4298 for more detail."
 )
 class TestDynamicShapeModels(unittest.TestCase):
diff --git a/test/pjrt/test_dynamic_plugin_tpu.py b/test/pjrt/test_dynamic_plugin_tpu.py
index f199797afc2..151207f6bc1 100644
--- a/test/pjrt/test_dynamic_plugin_tpu.py
+++ b/test/pjrt/test_dynamic_plugin_tpu.py
@@ -20,7 +20,7 @@ def setUpClass(cls):
   @staticmethod
   def _assert_tpus_exist(index=0):
     del index
-    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
+    assert xr.device_type() == 'TPU'
 
   def test_single_process(self):
     with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
diff --git a/test/test_autocast.py b/test/test_autocast.py
index ca1f26c05ec..19101e27659 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -348,8 +348,7 @@ def compare(first, second):
     self.assertFalse(self.is_autocast_enabled())
 
 
-@unittest.skipIf(
-    xm.xla_device_hw(torch_xla.device()) != 'TPU', f"TPU autocast test.")
+@unittest.skipIf(xr.device_type() != 'TPU', f"TPU autocast test.")
 class TestAutocastTPU(TestAutocastBase):
 
   @classmethod
@@ -405,7 +404,7 @@ class TestOtherOps(unittest.TestCase):
 
   # On TPU, the input of batch norm is casted into fp32, see torch_xla/csrc/autocast_mode.cpp
   @unittest.skipIf(
-      xm.xla_device_hw(torch_xla.device()) != 'TPU',
+      xr.device_type() != 'TPU',
       "the behavior of batch_norm autocast on TPU is different from others")
   def test_batch_norm_tpu(self):
     device = torch_xla.device()
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 49229b17cff..d6c2ef57c3e 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -334,7 +334,7 @@ def _get_physical_tpu_mesh(self, devices: np.ndarray) -> np.ndarray:
       A np.ndarray of device logical ordinals with shape [global_x, global_y,
       global_z]. On v2 and v3, global_z is instead cores_per_chip (i.e., 2).
     """
-    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
+    assert xr.device_type() == 'TPU'
     # coords is a 3-dims tuple representing the device in physical mesh
    device_coords = [self.device_attributes[d]['coords'] for d in devices]
     dims = tuple(d + 1 for d in max(device_coords))
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 2e274190db7..31465888098 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -82,12 +82,12 @@ def _maybe_select_default_device():
 
 
 def device_type() -> Optional[str]:
-  """Returns the current PjRt device type.
+  """Returns the current PJRT device type.
 
   Selects a default device if none has been configured
 
   Returns:
-    A string representation of the device.
+    A string representation of the PJRT device: "CPU", "TPU", etc.
   """
   pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
   return pjrt_device.split('_')[0] if pjrt_device else pjrt_device
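
A minimal usage sketch of the runtime API this diff migrates to, not part of the patch itself. It assumes torch_xla is installed and PJRT_DEVICE is set (e.g. PJRT_DEVICE=TPU); the test class and method names are hypothetical, chosen only to mirror the skip decorators updated above.

import unittest

import torch_xla.runtime as xr


# Skip the whole class unless the configured PJRT device is a TPU,
# using xr.device_type() in place of xm.xla_device_hw(torch_xla.device()).
@unittest.skipIf(xr.device_type() != 'TPU', "requires a TPU runtime")
class ExampleTpuOnlyTest(unittest.TestCase):

  def test_device_type_is_tpu(self):
    # device_type() returns the PJRT device string, e.g. "CPU" or "TPU".
    self.assertEqual(xr.device_type(), 'TPU')


if __name__ == '__main__':
  unittest.main()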