diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index 207a12dbff8d..388f074b3757 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -1332,6 +1332,7 @@ def mask_mod(b, h, q, kv):
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
 
     @supported_platform
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)
         attention = functools.partial(flex_attention, score_mod=score_mod)
diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index 494d0cf89082..2c775232aeed 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -33,6 +33,7 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    skipIfRocmNotEnoughMemory,
     skipIfRocm,
     TEST_WITH_ROCM,
 )
@@ -717,6 +718,8 @@ def test_conv_backend(self):
 
         self.assertIn("NoValidChoicesError", str(context.exception))
 
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @skipIfRocmNotEnoughMemory(30)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -766,6 +769,8 @@ def f(x, y):
             act = f(x, y)
             torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)
 
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @skipIfRocmNotEnoughMemory(30)
     def test_non_contiguous_input_mm_plus_mm(self):
         x1 = rand_strided((50257, 32768), (1, 50304), device="cuda")
         y1 = rand_strided((32768, 768), (768, 1), device="cuda")
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 74ad5c349c1d..35a839f11524 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1872,14 +1872,26 @@ def wrap_fn(self, *args, **kwargs):
         return wrap_fn
     return dec_fn
 
-def skipIfRocmArch(arch: tuple[str, ...]):
+# Checks if current ROCm device has enough VRAM against the required amount in GB
+def skipIfRocmNotEnoughMemory(required_amount):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] in arch:
-                    reason = f"skipIfRocm: test skipped on {arch}"
+                device = torch.cuda.current_device()
+                props = torch.cuda.get_device_properties(device)
+
+                total = props.total_memory / (1024 ** 3)  # in GB
+                # This will probably return 0 because it only counts tensors
+                # and doesn't take into account any small supporting allocations
+                allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
+                free_global = total - allocated
+
+                result = free_global > required_amount
+
+                if not result:
+                    reason = f"skipIfRocm: Not enough free VRAM on current ROCm device. " \
+                             f"Available: {free_global:.2f} GB | Required: {required_amount:.2f} GB."
                     raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
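
For context, here is a minimal sketch of how the new decorator is intended to be consumed from a test file, mirroring the usage added to test_max_autotune.py above. The class and test names below are hypothetical and not part of this patch; the sketch assumes a CUDA/ROCm build of PyTorch and the decorator as added to common_utils.py in the last hunk.

import torch
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfRocmNotEnoughMemory,
    TestCase,
)


class ExampleLargeWorkspaceTest(TestCase):  # hypothetical test class, for illustration only
    # Skip on ROCm devices reporting less than ~30 GB of free VRAM; on
    # non-ROCm builds TEST_WITH_ROCM is False and the test runs normally.
    @skipIfRocmNotEnoughMemory(30)
    def test_large_matmul(self):  # hypothetical test
        x = torch.randn(8192, 8192, device="cuda")
        y = torch.randn(8192, 8192, device="cuda")
        torch.testing.assert_close(x @ y, torch.matmul(x, y))


if __name__ == "__main__":
    run_tests()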