diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index dbb443c6b..c0c852332 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -105,6 +105,13 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
                 names.append(name)
 
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+
+        # Handle the case where all parallelism dimensions are 1
+        # We still need to create a mesh with named dimensions for submesh access
+        if not dims:
+            dims = [1]
+            names = ["world"]
+
         mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
 
         # Create all the submesh here to ensure all required process groups are
@@ -137,6 +144,11 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
         if self.etp == 1 and self.tp_enabled:
             ep_mesh_dim_names.append("tp")
 
+        # Ensure dp_cp submesh exists even when all parallelism dimensions are 1
+        # This is required for fault tolerance functionality
+        if not dp_cp_mesh_dim_names and "world" in names:
+            dp_cp_mesh_dim_names = ["world"]
+
         mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
         mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
         mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
@@ -156,6 +168,13 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
                 names.append(name)
 
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+
+        # Handle the case where all parallelism dimensions are 1
+        # We still need to create a mesh with named dimensions for submesh access
+        if not dims:
+            dims = [1]
+            names = ["world"]
+
         mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
 
         # Create all the submesh here to ensure all required process groups are
@@ -178,6 +197,12 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             dp_shard_cp_mesh_dim_names.append("cp")
             dp_cp_mesh_dim_names.append("cp")
 
+        # Ensure dp_cp submesh exists even when all parallelism dimensions are 1
+        # This is required for fault tolerance functionality
+        if not dp_cp_mesh_dim_names and "world" in names:
+            dp_cp_mesh_dim_names = ["world"]
+
+
         if dp_mesh_dim_names != []:
             mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
         if dp_shard_cp_mesh_dim_names != []: