1 change: 1 addition & 0 deletions test/inductor/test_flex_decoding.py
@@ -1332,6 +1332,7 @@ def mask_mod(b, h, q, kv):
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)
 
     @supported_platform
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)
         attention = functools.partial(flex_attention, score_mod=score_mod)
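For context, the gate added here follows the standard `unittest.skipIf` pattern (the bare `skipIf` name is presumably `unittest.skipIf` or an equivalent re-export in this test file), and `PLATFORM_SUPPORTS_FLASH_ATTENTION` is a capability flag from PyTorch's internal CUDA test helpers. A minimal sketch of the same gating pattern, with the flag stubbed to a constant since its real value is computed from the device:

```python
import unittest

# Stand-in for PLATFORM_SUPPORTS_FLASH_ATTENTION; the real flag is computed
# from the device at import time, this constant is purely illustrative.
PLATFORM_SUPPORTS_FLASH_ATTENTION = False

class TestFlexDecoding(unittest.TestCase):
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
    def test_windowed_no_mask_vs_sdpa(self):
        ...  # body only runs when the platform supports flash attention
```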
5 changes: 5 additions & 0 deletions test/inductor/test_max_autotune.py
@@ -33,6 +33,7 @@
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    skipIfRocmNotEnoughMemory,
     skipIfRocm,
     TEST_WITH_ROCM,
 )
@@ -717,6 +718,8 @@ def test_conv_backend(self):
 
         self.assertIn("NoValidChoicesError", str(context.exception))
 
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @skipIfRocmNotEnoughMemory(30)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -766,6 +769,8 @@ def f(x, y):
         act = f(x, y)
         torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)
 
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @skipIfRocmNotEnoughMemory(30)
     def test_non_contiguous_input_mm_plus_mm(self):
         x1 = rand_strided((50257, 32768), (1, 50304), device="cuda")
         y1 = rand_strided((32768, 768), (768, 1), device="cuda")
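For context, `rand_strided` builds a tensor with an explicit (size, stride) layout, and strides of `(1, 50304)` for a `(50257, 32768)` shape describe a column-major buffer whose leading dimension is padded to 50304 elements, hence non-contiguous. A minimal CPU sketch of the same layout with shrunken, hypothetical sizes so it runs anywhere:

```python
import torch

# Shrunken stand-ins for the (50257, 32768) / (1, 50304) layout in the test:
# stride (1, rows_padded) is column-major with padding, hence non-contiguous.
rows, cols, rows_padded = 50, 32, 64
storage = torch.empty(rows_padded * cols)
x = storage.as_strided((rows, cols), (1, rows_padded))

assert not x.is_contiguous()
print(x.stride())  # (1, 64)
```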
20 changes: 16 additions & 4 deletions torch/testing/_internal/common_utils.py
@@ -1872,14 +1872,26 @@ def wrap_fn(self, *args, **kwargs):
         return wrap_fn
     return dec_fn
 
-def skipIfRocmArch(arch: tuple[str, ...]):
+# Checks whether the current ROCm device has enough free VRAM against the required amount, in GB
+def skipIfRocmNotEnoughMemory(required_amount):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] in arch:
-                    reason = f"skipIfRocm: test skipped on {arch}"
+                device = torch.cuda.current_device()
+                props = torch.cuda.get_device_properties(device)
+
+                total = props.total_memory / (1024 ** 3)  # in GB
+                # This will probably report close to 0 because it only counts tensors
+                # and doesn't take into account any small supporting allocations
+                allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
+                free_global = total - allocated
+
+                result = free_global > required_amount
+
+                if not result:
+                    reason = f"skipIfRocm: Not enough free VRAM on current ROCm device. " \
+                             f"Available: {free_global:.2f} GB | Required: {required_amount:.2f} GB."
                     raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
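As a sanity check, the decorator's free-VRAM estimate can be reproduced standalone; a minimal sketch assuming a visible CUDA/ROCm device (the helper name `free_vram_gb` is illustrative, not part of the PR):

```python
import torch

def free_vram_gb(device=None):
    """Rough free-VRAM estimate matching the decorator's logic:
    total device memory minus tensors tracked by this process."""
    if device is None:
        device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    total = props.total_memory / (1024 ** 3)
    allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
    return total - allocated

if torch.cuda.is_available():
    print(f"estimated free VRAM: {free_vram_gb():.2f} GB")
```

As the PR's own comment notes, `torch.cuda.memory_allocated()` only tracks tensors owned by this process, so the estimate ignores both small supporting allocations and any other processes sharing the GPU.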