Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aten/src/ATen/native/cuda/Blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1048,9 +1048,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
#ifndef USE_ROCM
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
"Multiplication of two Float8_e5m2 matrices is not supported");
#endif
if (bias) {
TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
Expand Down
2 changes: 1 addition & 1 deletion functorch/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# PyTorch forward-mode is not mature yet
from functorch import functionalize
from torch._functorch.deprecated import functionalize
from torch._functorch.apis import chunk_vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import hessian, jacfwd, jvp
12 changes: 11 additions & 1 deletion test/distributed/_composable/fsdp/test_fully_shard_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
)
from torch.distributed.tensor import DTensor, init_device_mesh, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
TEST_CUDA,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
Expand All @@ -40,7 +43,9 @@
)
from torch.testing._internal.common_utils import (
get_cycles_per_ms,
NAVI_ARCH,
run_tests,
skipIfRocmArch,
wrapSwapTensorsTest,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
Expand Down Expand Up @@ -93,6 +98,7 @@ def world_size(self) -> int:
return 4

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skipIfRocmArch(NAVI_ARCH) # Supported in future releases
def test_param_registration_after_forward(self):
"""Tests the parameter registration after forward."""
device = torch.device("cuda", 0)
Expand Down Expand Up @@ -199,6 +205,7 @@ def world_size(self) -> int:

@unittest.skipIf(not TEST_CUDA, "no cuda")
@wrapSwapTensorsTest(True)
@skipIfRocmArch(NAVI_ARCH) # Supported in future releases
def test_to_float64_after_init(self):
"""Tests that the user can cast the module to float64 after init."""
# NOTE: Test fp64 instead of a lower precision dtype like bf16 for
Expand Down Expand Up @@ -317,6 +324,9 @@ def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:

@skip_if_lt_x_gpu(2)
@test_compiled_fsdp(compile_compute_on_module=Transformer)
@unittest.skipIf(
not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
)
def test_train_parity_multi_group(self):
"""
Tests train parity against DDP when using multiple parameter groups for
Expand Down
24 changes: 12 additions & 12 deletions test/distributed/_composable/test_replicate_with_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ def test_bucketing_coalesced_op(self):
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
fc = FileCheck()
for i in range(3):
fc.check("cpp_fused_").check(
"torch.ops._c10d_functional.all_reduce_coalesced_.default("
)
fc.check("cpp_fused_")
for i in range(3):
fc.check("torch.ops._c10d_functional.all_reduce_coalesced_.default(")
for i in range(3):
fc.check("torch.ops._c10d_functional.wait_tensor.default")

Expand All @@ -343,9 +343,9 @@ def test_bucketing_coalesced_op(self):
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
fc = FileCheck()
for i in range(3):
fc.check("cpp_fused_").check(
"torch.ops._c10d_functional.all_reduce_coalesced_.default("
)
fc.check("cpp_fused_")
for i in range(3):
fc.check("torch.ops._c10d_functional.all_reduce_coalesced_.default(")
for i in range(3):
fc.check("torch.ops._c10d_functional.wait_tensor.default")

Expand All @@ -372,9 +372,9 @@ def test_bucketing_concat_op(self):
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
fc = FileCheck()
for i in range(3):
fc.check("aten.flatten.using_ints(").check("cpp_fused_").check(
"torch.ops._c10d_functional.all_reduce_.default("
)
fc.check("aten.flatten.using_ints(").check("cpp_fused_")
for i in range(3):
fc.check("torch.ops._c10d_functional.all_reduce_.default(")
for i in range(3):
fc.check("torch.ops._c10d_functional.wait_tensor.default")
fc.run(code)
Expand All @@ -384,9 +384,9 @@ def test_bucketing_concat_op(self):
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
fc = FileCheck()
for i in range(3):
fc.check("aten.flatten.using_ints(").check("cpp_fused_").check(
"torch.ops._c10d_functional.all_reduce_.default("
)
fc.check("aten.flatten.using_ints(").check("cpp_fused_")
for i in range(3):
fc.check("torch.ops._c10d_functional.all_reduce_.default(")
for i in range(3):
fc.check("torch.ops._c10d_functional.wait_tensor.default")
fc.run(code)
Expand Down
4 changes: 2 additions & 2 deletions test/distributed/_tools/test_sac_ilp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
get_optimal_checkpointing_policy_per_module,
sac_milp,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_cuda import TEST_CUDA, PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
run_tests,
skipIfTorchDynamo,
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_sac_ilp_case1(self):

@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
@skipIfRocmArch(NAVI_ARCH)
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
def test_sac_ilp_case2(self):
"""
This is a case where the memory budget is not binding, meaning that no
Expand Down
8 changes: 7 additions & 1 deletion test/distributed/elastic/test_control_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
TORCH_WORKER_SERVER_SOCKET,
worker_main,
)
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
from torch.testing._internal.common_utils import (
requires_cuda,
run_tests,
skipIfRocm,
TestCase,
)


class UnixHTTPConnection(HTTPConnection):
Expand Down Expand Up @@ -152,6 +157,7 @@ def test_dump_nccl_trace_pickle_with_json(self) -> None:
)
self.assertEqual(resp.status, 200)

@skipIfRocm # skipped upstream too
def test_tcp(self) -> None:
import requests

Expand Down
4 changes: 4 additions & 0 deletions test/distributed/fsdp/test_fsdp_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@
TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
NAVI_ARCH,
parametrize,
run_tests,
skipIfRocmArch,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
)
Expand Down Expand Up @@ -160,6 +163,7 @@ def test_nested_always_wrap_model(

@skip_if_lt_x_gpu(2)
@parametrize(params, configs, subtest_name)
@skipIfRocmArch(NAVI_ARCH) # Supported in future releases
def test_transformer(
self,
cpu_offload: CPUOffload,
Expand Down
5 changes: 5 additions & 0 deletions test/distributed/fsdp/test_fsdp_hybrid_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import auto, Enum
from functools import partial
from typing import List, Optional, Tuple
import unittest

import torch
import torch.distributed as dist
Expand All @@ -31,6 +32,9 @@
FSDPTest,
TransformerWithSharedParams,
)
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
Expand Down Expand Up @@ -227,6 +231,7 @@ def test_invalid_pg_specification_raises(self):
# resharded after forward.

@skip_if_lt_x_gpu(2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
def test_fsdp_hybrid_shard_basic_setup(self):
"""
Tests basic functionality of HYBRID_SHARD and _HYBRID_SHARD_ZERO2:
Expand Down
4 changes: 4 additions & 0 deletions test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
DEVICEInitMode,
Expand Down Expand Up @@ -236,6 +237,9 @@ def _build_model_and_optim(
return model, optim, ref_model, ref_optim

@skip_if_lt_x_gpu(2)
@unittest.skipIf(
not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
)
def test_sharded_grad_scaler_found_inf(self):
self.run_subtests(
{
Expand Down
2 changes: 2 additions & 0 deletions test/distributed/optim/test_zero_redundancy_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,8 @@ def closure_sharded(input_tensor=input_tensor):
torch.testing.assert_close(
loss_ddp,
loss_sharded_optim,
atol=1.6e-3,
rtol=3e-6,
msg="Losses differ between local optimizer and ZeRO",
)
self._check_same_model_params(
Expand Down
2 changes: 2 additions & 0 deletions test/distributed/tensor/parallel/test_tp_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Transformer,
with_comms,
)
from unittest import skipIf


c10d_functional = torch.ops.c10d_functional
Expand Down Expand Up @@ -414,6 +415,7 @@ def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype):
+ f"{str(dtype).split('.')[-1]}_"
+ f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}",
)

@skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts):
# Sample a subset of `requires_grad` patterns
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/test_compute_comm_reordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def func(a):
# above the 2nd matmul.
(
FileCheck()
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("extern_kernels.mm")
.run(code)
Expand Down
14 changes: 6 additions & 8 deletions test/distributed/test_inductor_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,10 +659,10 @@ def func(inp, *, tag, ranks, group_size):
FileCheck()
.check("buf0 = empty_strided")
.check(".run(arg0_1, buf0")
.check("torch.ops._c10d_functional.all_reduce_.default(buf0")
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
.check("buf5 = empty_strided")
.check(".run(buf5, 16")
.check("torch.ops._c10d_functional.all_reduce_.default(buf0")
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
.check("return (buf0, buf5")
.run(code)
)
Expand Down Expand Up @@ -697,10 +697,10 @@ def func(inp, *, tag, ranks, group_size):
.check("buf0 = empty_strided")
.check("buf5 = empty_strided")
.check(".run(arg0_1, buf0, buf5, 16")
.check("torch.ops._c10d_functional.all_reduce_.default(buf0")
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
.check("buf6 = empty_strided")
.check(".run(buf6, 16")
.check("torch.ops._c10d_functional.all_reduce_.default(buf0")
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
.check("return (buf0, buf5, buf6")
.run(code)
)
Expand Down Expand Up @@ -1153,9 +1153,8 @@ def func(inp, *, tag, ranks, group_size):
)
.check("buf2 = buf1[0]")
.check("buf3 = buf1[1]")
.check("torch.ops._c10d_functional.wait_tensor.default(buf2")
.check("buf7 = buf0; del buf0 # reuse")
.check(".run(buf7, 16")
.check("torch.ops._c10d_functional.wait_tensor.default(buf2")
.check("torch.ops._c10d_functional.wait_tensor.default(buf3")
.check("return (buf2, buf6, buf7, buf3")
.run(code)
Expand Down Expand Up @@ -1199,9 +1198,8 @@ def func(inp, *, tag, ranks, group_size):
)
.check("buf2 = buf1[0]")
.check("buf3 = buf1[1]")
.check("torch.ops._c10d_functional.wait_tensor.default(buf2")
.check("buf7 = buf0; del buf0 # reuse")
.check(".run(buf7, 16")
.check("torch.ops._c10d_functional.wait_tensor.default(buf2")
.check("torch.ops._c10d_functional.wait_tensor.default(buf3")
.check("return (buf2, buf6, buf7, buf3")
.run(code)
Expand Down
Loading