diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 5bf14ae094e0..e797444d0eb0 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -66,6 +66,8 @@
     TEST_WITH_DEV_DBG_ASAN,
     TEST_WITH_ROCM,
     TestCase,
+    is_arch,
+    NAVI_ARCH,
 )
 
 
@@ -625,6 +627,12 @@ def _helper_test_extra_cuda_context_by_memory(self):
         # Rank 0 takes a snapshot before collective -- this snapshot should have
         # included rank 0's own context.
         if self.rank == 0:
+            # We need this extra sleep for NAVI_ARCH because rccl_init inside init_process_group
+            # is happening in a separate process and it is taking longer to finish on NAVI_ARCH.
+            # Sleeping here ensures that the init is completed successfully and mem_get_info can
+            # get stable numbers.
+            if is_arch(NAVI_ARCH):
+                time.sleep(5)
             free, total = torch.cuda.mem_get_info(device)
             used_before = float(total - free)
diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py
index a9f3daa67528..587214c5b735 100644
--- a/test/inductor/test_decompose_mem_bound_mm.py
+++ b/test/inductor/test_decompose_mem_bound_mm.py
@@ -12,8 +12,9 @@
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
+    NAVI3_ARCH,
+    is_arch,
     patch_test_members,
-    is_navi3_arch,
     parametrize,
     TEST_XPU,
 )
@@ -49,5 +50,11 @@ def forward(self, input1, input2):
         output = torch.mm(input1, input2)
         return output
 
+# We have to increase tolerance for navi3 because all fp16, bf16
+# GEMMs operations have an accuracy issue caused by hardware limitation
+default_atol = 3e-3 if is_arch(NAVI3_ARCH) else 1e-3
+default_rtol = 4e-3 if is_arch(NAVI3_ARCH) else 1e-3
+
 @requires_gpu
 @unittest.skipIf(
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index f7fd0dab128e..5ee098bd013d 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -108,11 +108,11 @@
 NAVI3_ARCH = ("gfx1100", "gfx1101")
 NAVI4_ARCH = ("gfx1200", "gfx1201")
 
 
-def is_navi3_arch():
+def is_arch(arch_list):
     if torch.cuda.is_available():
         prop = torch.cuda.get_device_properties(0)
         gfx_arch = prop.gcnArchName.split(":")[0]
-        if gfx_arch in NAVI3_ARCH:
+        if gfx_arch in arch_list:
             return True
     return False