diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index 46acd0c1cdca3..dfd6dce40b999 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -22,7 +22,10 @@
 )
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_BF16,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+)
 from torch.testing._internal.common_device_type import (
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
@@ -1582,6 +1585,7 @@ def mask_mod(b, h, q, kv):
         self.assertEqual(out[:, :, M:, :].sum(), 0)
 
     @supported_platform
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_no_mask_vs_sdpa(self, device):
         score_mod = _generate_windowed(1000)
         attention = functools.partial(flex_attention, score_mod=score_mod)
diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index d1c5637d82008..eba9fbaf131c5 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -37,11 +37,15 @@
 )
 from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
+from torch.testing._internal.common_device_type import largeTensorTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_WINDOWS,
     parametrize,
     TEST_WITH_ROCM,
+    MI300_ARCH,
+    runOnRocmArch,
+    skipIfXpu,
 )
 from torch.testing._internal.logging_utils import multiple_logs_to_string
 from torch.utils._triton import has_triton_tma_device
@@ -54,7 +58,6 @@
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
-from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu
 from torch.testing._internal.inductor_utils import (
     get_func_call,
     get_kernel_launch,
@@ -804,6 +807,8 @@ def test_conv_backend(self):
 
         self.assertIn("NoValidChoicesError", str(context.exception))
 
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @largeTensorTest("30 GB", device=GPU_TYPE)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -856,6 +861,8 @@ def f(x, y):
     # TODO: fix accuracy failure of the triton template on XPU.
     # and enable this test case.
     @skipIfXpu
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @largeTensorTest("30 GB", device=GPU_TYPE)
     def test_non_contiguous_input_mm_plus_mm(self):
         x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE)
         y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE)