diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index b5ec59dc291c..6b34c19431f1 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -22,7 +22,11 @@ ) from torch.testing import FileCheck from torch.testing._internal import common_utils -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, with_tf32_off +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, + PLATFORM_SUPPORTS_BF16, + with_tf32_off, +) from torch.testing._internal.common_device_type import ( flex_attention_supported_platform as supported_platform, instantiate_device_type_tests, @@ -1591,6 +1595,7 @@ def mask_mod(b, h, q, kv): self.assertEqual(out[:, :, M:, :].sum(), 0) @supported_platform + @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") def test_windowed_no_mask_vs_sdpa(self, device): score_mod = _generate_windowed(1000) attention = functools.partial(flex_attention, score_mod=score_mod) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index cb7b2a513ede..3c2fd6add14b 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -45,8 +45,7 @@ IS_WINDOWS, parametrize, TEST_WITH_ROCM, - NAVI_ARCH, - skipIfRocmArch, + skipIfRocmNotEnoughMemory, ) from torch.testing._internal.logging_utils import multiple_logs_to_string from torch.utils._triton import has_triton_tma_device @@ -819,6 +818,8 @@ def test_conv_backend(self): self.assertIn("NoValidChoicesError", str(context.exception)) + # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks + @skipIfRocmNotEnoughMemory(30) def test_non_contiguous_input_mm(self): """ Make sure the triton template can work with non-contiguous inputs without crash. 
def skipIfRocmNotEnoughMemory(required_amount):
    """Skip the decorated test on ROCm when the current GPU lacks free VRAM.

    Args:
        required_amount: Amount of free device memory the test needs, in GB
            (computed with 1024 ** 3 bytes per GB).

    Returns:
        A decorator for test methods. When ``TEST_WITH_ROCM`` is false the
        wrapped test is invoked without any memory check.

    Raises:
        unittest.SkipTest: Raised by the wrapper at call time (not at
            decoration time) when running on ROCm and the free memory on the
            current device is not strictly greater than ``required_amount``.
    """
    bytes_per_gb = 1024 ** 3

    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            if TEST_WITH_ROCM:
                device = torch.cuda.current_device()
                # Ask the driver for the device-wide free memory.
                # torch.cuda.memory_allocated() only counts tensors owned by
                # this process' allocator, so "total - allocated" would
                # ignore memory held by other processes or by non-tensor
                # allocations and over-report what is actually available.
                free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
                free_global = free_bytes / bytes_per_gb  # in GB

                if not free_global > required_amount:
                    reason = f"skipIfRocm: Not enough free VRAM on current ROCm device. " \
                        f"Available: {free_global:.2f} GB | Required: {required_amount:.2f} GB."
                    raise unittest.SkipTest(reason)
            # Non-ROCm runs, and ROCm runs with enough memory, execute the test.
            return fn(self, *args, **kwargs)
        return wrap_fn
    return dec_fn