diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index f7fd0dab128e4..2726344866a41 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1894,9 +1894,26 @@ def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
+<<<<<<< HEAD
                 prop = torch.cuda.get_device_properties(0)
                 if prop.gcnArchName.split(":")[0] in arch:
                     reason = f"skipIfRocm: test skipped on {arch}"
+=======
+                device = torch.cuda.current_device()
+                props = torch.cuda.get_device_properties(device)
+
+                total = props.total_memory / (1024 ** 3)  # in GB
+                # This will probably return 0 because it only counts tensors
+                # and doesn't take into account any small supporting allocations
+                allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
+                free_global = total - allocated
+
+                result = free_global > required_amount
+
+                if not result:
+                    reason = f"skipIfRocm: Not enough free VRAM on current ROCm device. " \
+                             f"Available: {free_global:.2f} GB | Required: {required_amount:.2f} GB."
+>>>>>>> f78730679a1 (Formatting code style)
                     raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
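
Note: the hunk above still contains unresolved merge-conflict markers (<<<<<<< HEAD / ======= / >>>>>>>); if committed as-is, the file would not even parse, so the conflict has to be resolved before this can merge. Below is a minimal sketch of one possible resolution, under the assumption that the incoming VRAM gate is meant to be its own decorator rather than a replacement for the existing gcnArchName check on the HEAD side (which takes an `arch` argument, not `required_amount`). The name `skip_if_rocm_low_vram` is hypothetical; `required_amount` (in GB) and the skip message are taken from the incoming side of the conflict. The sketch also swaps `torch.cuda.memory_allocated()` for `torch.cuda.mem_get_info()`, since, as the diff's own comment notes, `memory_allocated()` only counts tensors allocated through PyTorch's caching allocator and will under-report real usage.

    # Sketch only, not the PR's final code. `skip_if_rocm_low_vram` is a
    # hypothetical name; `required_amount` (GB) comes from the conflicted hunk.
    import unittest
    from functools import wraps

    import torch
    from torch.testing._internal.common_utils import TEST_WITH_ROCM


    def skip_if_rocm_low_vram(required_amount: float):
        """Skip a test on ROCm when less than `required_amount` GB of VRAM is free."""
        def dec_fn(fn):
            @wraps(fn)
            def wrap_fn(self, *args, **kwargs):
                if TEST_WITH_ROCM:
                    device = torch.cuda.current_device()
                    # mem_get_info() reports (free, total) in bytes for the device,
                    # reflecting all allocations on the GPU, not just PyTorch tensors.
                    free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
                    free_gb = free_bytes / (1024 ** 3)
                    if free_gb <= required_amount:
                        reason = (
                            f"skipIfRocm: Not enough free VRAM on current ROCm device. "
                            f"Available: {free_gb:.2f} GB | Required: {required_amount:.2f} GB."
                        )
                        raise unittest.SkipTest(reason)
                return fn(self, *args, **kwargs)
            return wrap_fn
        return dec_fn

Usage would mirror the existing skipIfRocmArch decorator, e.g. decorating a memory-hungry test with `@skip_if_rocm_low_vram(8.0)` to require roughly 8 GB free before it runs.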