diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 1326a9acb61e..a75484be3d87 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -67,6 +67,8 @@
     TEST_WITH_DEV_DBG_ASAN,
     TEST_WITH_ROCM,
     TestCase,
+    is_arch,
+    NAVI_ARCH,
 )
 
 from torch.utils.cpp_extension import load_inline
@@ -590,6 +592,12 @@ def _helper_test_extra_cuda_context_by_memory(self):
         # Rank 0 takes a snapshot before collective -- this snapshot should have
         # included rank 0's own context.
         if self.rank == 0:
+            # We need this extra sleep for NAVI_ARCH because rccl_init inside init_process_group
+            # happens in a separate process and takes longer to finish on NAVI_ARCH.
+            # Sleeping here ensures that the init has completed successfully and mem_get_info
+            # reports stable numbers.
+            if is_arch(NAVI_ARCH):
+                time.sleep(5)
             free, total = torch.cuda.mem_get_info(device)
             used_before = float(total - free)
 
diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py
index 33104ce792d4..587e178518a9 100644
--- a/test/inductor/test_decompose_mem_bound_mm.py
+++ b/test/inductor/test_decompose_mem_bound_mm.py
@@ -11,7 +11,8 @@
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
-    is_navi3_arch,
+    NAVI3_ARCH,
+    is_arch,
     parametrize,
     skipIfXpu,
 )
@@ -49,8 +50,8 @@ def forward(self, input1, input2):
 
 # We have to increase tolerance for navi3 because all fp16, bf16
 # GEMMs operations have an accuracy issue caused by hardware limitation
-default_atol = 3e-3 if is_navi3_arch() else 1e-3
-default_rtol = 4e-3 if is_navi3_arch() else 1e-3
+default_atol = 3e-3 if is_arch(NAVI3_ARCH) else 1e-3
+default_rtol = 4e-3 if is_arch(NAVI3_ARCH) else 1e-3
 
 
 @requires_gpu
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 74ad5c349c1d..51236a794376 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -112,11 +112,11 @@
 NAVI3_ARCH = ("gfx1100", "gfx1101")
 NAVI4_ARCH = ("gfx1200", "gfx1201")
 
-def is_navi3_arch():
+def is_arch(arch_list):
     if torch.cuda.is_available():
         prop = torch.cuda.get_device_properties(0)
         gfx_arch = prop.gcnArchName.split(":")[0]
-        if gfx_arch in NAVI3_ARCH:
+        if gfx_arch in arch_list:
             return True
     return False