
Chao/xccl ut #12

Open: wants to merge 34 commits into base main from chao/xccl_ut.
Changes from all commits (34 commits):
b13b64d  enable XPU tests (zhangxiaoli73, Jan 9, 2025)
314b78e  update (Chao1Han, Jan 17, 2025)
88ed5d2  enable some XPU distributed tests (zhangxiaoli73, Jan 17, 2025)
b89bedf  add more changes for XPU (zhangxiaoli73, Jan 22, 2025)
3d97844  correct set index (Chao1Han, Jan 23, 2025)
9d89a1f  enable xpu tests (zhangxiaoli73, Jan 23, 2025)
fd57737  enable all tests on XPU (zhangxiaoli73, Jan 24, 2025)
c679690  enable TP (zhangxiaoli73, Jan 24, 2025)
8c306ce  add comm test (Chao1Han, Jan 24, 2025)
e601f57  enable FSDP2 (zhangxiaoli73, Feb 8, 2025)
d02fb2a  add pileline (Chao1Han, Feb 8, 2025)
0d682f6  update ddp (Chao1Han, Feb 10, 2025)
0f99a76  update (Chao1Han, Feb 10, 2025)
b794de0  update comm ut (Chao1Han, Feb 10, 2025)
c776d8a  fix some changes (zhangxiaoli73, Feb 10, 2025)
baed9bd  cuda to xpu (zhangxiaoli73, Feb 10, 2025)
cc24e89  fix fake pg and skip gloo test (Chao1Han, Feb 11, 2025)
a5f1ca3  enable tests (zhangxiaoli73, Feb 13, 2025)
954d63d  make changes (zhangxiaoli73, Feb 13, 2025)
7451b9d  fix fsdp (Chao1Han, Feb 13, 2025)
ce0ebf0  change (Chao1Han, Feb 14, 2025)
52c7074  change (Chao1Han, Feb 14, 2025)
213177f  change (Chao1Han, Feb 14, 2025)
7763c87  try to skip FSDPTestMultiThread for xpu (Chao1Han, Feb 19, 2025)
07ac0d1  update (Chao1Han, Feb 20, 2025)
5bb00c9  Merge remote-tracking branch 'upstream/main' into chao/xccl_ut (Chao1Han, Feb 20, 2025)
f78a563  update (Chao1Han, Feb 25, 2025)
96e78b0  update dtensor (Chao1Han, Feb 27, 2025)
55e5ddc  update xpu commit (Chao1Han, Mar 14, 2025)
b8fbdc1  update (Chao1Han, Mar 19, 2025)
64dae28  update fsdp test (Chao1Han, Mar 27, 2025)
c4b22e6  update (Chao1Han, Mar 28, 2025)
fe87579  update (Chao1Han, Mar 28, 2025)
da5ea88  hardcode world_size to 4 (Chao1Han, Apr 11, 2025)
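Taken together, these commits move the FSDP2 and DTensor distributed tests from CUDA-only assumptions to Intel XPU, and the file diffs below are mostly mechanical cuda -> xpu substitutions. As a hedged illustration of the underlying pattern only (not code from this PR; pick_device is a hypothetical helper name), a device-agnostic script would typically select its accelerator like this:

    import torch

    def pick_device() -> torch.device:
        """Prefer XPU when present, then CUDA, then CPU (illustrative helper)."""
        # hasattr guard: torch.xpu only exists in builds/versions with XPU support
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    if __name__ == "__main__":
        device = pick_device()
        x = torch.rand((4, 8), device=device)
        print(f"running on {device}, checksum={x.sum().item():.4f}")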
25 changes: 12 additions & 13 deletions test/distributed/_composable/fsdp/test_fully_shard_autograd.py
@@ -21,7 +21,7 @@
     FSDPTestMultiThread,
     MLP,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_XPU
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
@@ -31,7 +31,7 @@
 class TestFullyShardAutograd(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.xpu.device_count())

     def _reduce_1d_partial_grads(
         self, module: nn.Module, group: Optional[dist.ProcessGroup] = None
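The hunk above derives the test's world size from the number of visible XPU devices, capped at 4 so machines with fewer cards still run a smaller configuration. A standalone, hedged sketch of the same idea (the helper name is illustrative, not from the PR):

    import torch

    def xpu_world_size(cap: int = 4) -> int:
        # torch.xpu.device_count() is 0 when no XPU is visible; the test
        # harness would then skip the multi-process test entirely.
        return min(cap, torch.xpu.device_count())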
@@ -58,7 +58,7 @@ def _test_unused_forward_output(self, reshard_after_forward: Union[bool, int]):
         local_batch_size = 2
         global_batch_size, dim = (self.world_size * local_batch_size, 24)
         model = DoubleLinear(dim=dim, use_second_linear=True)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).xpu()
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
@@ -68,7 +68,7 @@ def _test_unused_forward_output(self, reshard_after_forward: Union[bool, int]):
         for iter_idx in range(10):
             # Use all forward outputs in the loss/backward for the first half
             # of the iterations and only the 1st forward output for the rest
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device="xpu")
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -104,7 +104,7 @@ def _test_unused_forward_module(self, reshard_after_forward: Union[bool, int]):
         local_batch_size, dim = (2, 24)
         global_batch_size = self.world_size * local_batch_size
         model = DoubleLinear(dim=dim, use_second_linear=False)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).xpu()
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model.lin2, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
@@ -113,7 +113,7 @@ def _test_unused_forward_module(self, reshard_after_forward: Union[bool, int]):

         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device="xpu")
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -214,7 +214,7 @@ def forward(self, x: torch.Tensor):
             Module(dim),
             FromContainerType(container_type),
         )
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).xpu()
         for module in model:
             fully_shard(module)
         fully_shard(model)
@@ -223,7 +223,7 @@ def forward(self, x: torch.Tensor):

         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device="xpu")
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -245,7 +245,6 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_post_acc_grad_hook_runs(self):
         param_name_to_hook_count = collections.defaultdict(int)

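The hunk above drops the CUDA-only skip from the multi-thread hook test, while the first hunk of this file imports TEST_XPU alongside run_tests. One way such a guard could be made device-aware is sketched below; this is a hedged illustration, not code from the diff (requires_accelerator is a hypothetical name):

    import unittest

    from torch.testing._internal.common_utils import TEST_CUDA, TEST_XPU

    # Skip when neither a CUDA nor an XPU device is visible.
    requires_accelerator = unittest.skipIf(
        not (TEST_CUDA or TEST_XPU), "no CUDA or XPU device available"
    )

    class SmokeTest(unittest.TestCase):
        @requires_accelerator
        def test_device_visible(self):
            self.assertTrue(TEST_CUDA or TEST_XPU)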
@@ -260,7 +259,7 @@ def hook(param_name: str, param: torch.Tensor) -> None:
             param_hook = functools.partial(hook, param_name)
             param.register_post_accumulate_grad_hook(param_hook)

-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device="xpu")
         model(inp).sum().backward()
         param_names = {param_name for param_name, _ in model.named_parameters()}
         self.assertEqual(param_names, set(param_name_to_hook_count.keys()))
@@ -271,7 +270,7 @@ def hook(param_name: str, param: torch.Tensor) -> None:
 class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.xpu.device_count(), 2)

     @skip_if_lt_x_gpu(2)
     def test_post_acc_grad_hook_optim_parity(self):
@@ -283,7 +282,7 @@ def test_post_acc_grad_hook_optim_parity(self):
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)

-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).xpu()
         for module in itertools.chain(ref_model.layers, [ref_model]):
             fully_shard(module)
         optim_kwargs = {"lr": 1e-2, "foreach": False}
@@ -312,7 +311,7 @@ def optim_hook(param: nn.Parameter) -> None:
             param.register_post_accumulate_grad_hook(optim_hook)

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="xpu")
         for _ in range(10):
             ref_loss = ref_model(inp).sum()
             ref_loss.backward()
(Second changed file: FSDP2 clip-grad-norm tests; filename and change counts not shown in this view.)
@@ -33,9 +33,9 @@ def _test_clip_grad_norm(
         dp_mesh: Optional[DeviceMesh] = None,
     ):
         vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
-        dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
+        dp_mesh = dp_mesh or init_device_mesh("xpu", (self.world_size,))
         torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
-        for _ in range(10):
+        for iter_idx in range(10):
             ref_optim.zero_grad()
             ref_model(inp).sum().backward()
             optim.zero_grad()
@@ -91,22 +91,22 @@ def _test_clip_grad_norm(
 class TestClipGradNormWorldSize2(_TestClipGradNormBase):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.xpu.device_count(), 2)

     @skip_if_lt_x_gpu(2)
     def test_clip_grad_norm_1d(self):
         for norm_type in (2, 1, float("inf")):
             torch.manual_seed(42)
             model_args = ModelArgs(dropout_p=0.0)
             model = Transformer(model_args)
-            ref_model = replicate(copy.deepcopy(model).cuda())
+            ref_model = replicate(copy.deepcopy(model).xpu())
             ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
             for module in model.modules():
                 if isinstance(module, TransformerBlock):
                     fully_shard(module)
             fully_shard(model)
             optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-            inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
+            inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="xpu")
             self._test_clip_grad_norm(
                 1, norm_type, ref_model, ref_optim, model, optim, inp
             )
@@ -115,14 +115,14 @@ def test_clip_grad_norm_1d(self):
 class TestClipGradNormWorldSize4(_TestClipGradNormBase):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 4)
+        return min(torch.xpu.device_count(), 4)

     @skip_if_lt_x_gpu(4)
     def test_clip_grad_norm_2d(self):
         for norm_type in (2, 1, 3, float("inf")):
             dp_size = 2
             global_mesh = init_device_mesh(
-                "cuda",
+                "xpu",
                 (dp_size, self.world_size // dp_size),
                 mesh_dim_names=("dp", "tp"),
             )
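For reference, the same 2D ("dp", "tp") mesh can be built on XPU outside the test harness. A hedged sketch, assuming at least 4 XPU devices and that the script is launched with torchrun so the default process group can form:

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    # Launch with: torchrun --nproc-per-node=4 this_script.py
    if torch.xpu.device_count() >= 4:
        # init_device_mesh sets up the default process group if needed
        mesh = init_device_mesh("xpu", (2, 2), mesh_dim_names=("dp", "tp"))
        dp_mesh, tp_mesh = mesh["dp"], mesh["tp"]
        print(dist.get_rank(), dp_mesh, tp_mesh)
        dist.destroy_process_group()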
@@ -132,7 +132,7 @@ def test_clip_grad_norm_2d(self):
             # has some more significant numeric differences from the TP
             model = MLPStack(16, with_seq_parallel=True)
             ref_model = replicate(
-                copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group()
+                copy.deepcopy(model).xpu(), process_group=dp_mesh.get_group()
             )
             ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
             model.parallelize(
@@ -142,7 +142,7 @@
                 reshard_after_forward=True,
             )
             optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-            inp = torch.randn(2, 16, device="cuda")
+            inp = torch.randn(2, 16, device="xpu")
             self._test_clip_grad_norm(
                 0.5, norm_type, ref_model, ref_optim, model, optim, inp, dp_mesh
             )