Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions torchtitan/distributed/parallel_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")

# Handle the case where all parallelism dimensions are 1
# We still need to create a mesh with named dimensions for submesh access
if not dims:
dims = [1]
names = ["world"]

mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submeshes here to ensure all required process groups are
Expand Down Expand Up @@ -137,6 +144,11 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
if self.etp == 1 and self.tp_enabled:
ep_mesh_dim_names.append("tp")

# Ensure dp_cp submesh exists even when all parallelism dimensions are 1
# This is required for fault tolerance functionality
if not dp_cp_mesh_dim_names and "world" in names:
dp_cp_mesh_dim_names = ["world"]

mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
Expand All @@ -156,6 +168,13 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")

# Handle the case where all parallelism dimensions are 1
# We still need to create a mesh with named dimensions for submesh access
if not dims:
dims = [1]
names = ["world"]

mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submeshes here to ensure all required process groups are
Expand All @@ -178,6 +197,12 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
dp_shard_cp_mesh_dim_names.append("cp")
dp_cp_mesh_dim_names.append("cp")

# Ensure dp_cp submesh exists even when all parallelism dimensions are 1
# This is required for fault tolerance functionality
if not dp_cp_mesh_dim_names and "world" in names:
dp_cp_mesh_dim_names = ["world"]


if dp_mesh_dim_names != []:
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
if dp_shard_cp_mesh_dim_names != []:
Expand Down