diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 4c66511ef..18d1871ae 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -173,9 +173,33 @@ def forward(self, tensors: List[torch.Tensor], cat_dim: int) -> torch.Tensor:
         Here we assume input tensors are:
         [TBE_output_0, ..., TBE_output_(n-1)]
         """
-        B = tensors[0].size(1 - cat_dim)
+        # Handle empty shards case (can happen in column-wise sharding)
+        if not tensors or len(tensors) == 0:
+            # Return empty tensor if no tensors provided
+            return torch.empty(0, 0, dtype=torch.float, device=self.current_device)
+
+        # Check if we are in TorchScript mode first to avoid global variable access issues
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # In TorchScript or JIT tracing mode, use all tensors and let FBGEMM handle empties
+            tensors_to_use = tensors
+        else:
+            if torch.fx._symbolic_trace.is_fx_tracing():
+                # During FX tracing, include all tensors to avoid control flow issues
+                tensors_to_use = tensors
+            else:
+                # Normal execution: filter out empty tensors
+                non_empty_tensors = []
+
+                for t in tensors:
+                    if t.numel() > 0 and t.size(cat_dim) > 0:
+                        non_empty_tensors.append(t)
+
+                tensors_to_use = non_empty_tensors if non_empty_tensors else tensors
+
+        # Use the first tensor to determine batch size
+        B = tensors_to_use[0].size(1 - cat_dim)
         return torch.ops.fbgemm.merge_pooled_embeddings(
-            tensors,
+            tensors_to_use,
             B,
             self.current_device,
             cat_dim,
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 4c9a29810..ebf8ecddb 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -371,6 +371,26 @@ def _emb_module_forward(
         lengths_or_offsets: torch.Tensor,
         weights: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        # Check if total embedding dimension is 0 (can happen in column-wise sharding)
+        total_D = sum(table.local_cols for table in self._config.embedding_tables)
+
+        if total_D == 0:
+            # For empty shards, return tensor with correct batch size but 0 embedding dimension
+            # Use tensor operations that are FX symbolic tracing compatible
+            if self.lengths_to_tbe:
+                # For lengths format, batch size equals lengths tensor size
+                # Create [B, 0] tensor using zeros_like and slicing
+                dummy = torch.zeros_like(lengths_or_offsets, dtype=torch.float)
+                return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+            else:
+                # For offsets format, batch size is one less than offset size
+                # Use tensor slicing to create batch dimension
+                batch_tensor = lengths_or_offsets[
+                    :-1
+                ]  # Remove last element to get batch size
+                dummy = torch.zeros_like(batch_tensor, dtype=torch.float)
+                return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+
         kwargs = {"indices": indices}

         if self.lengths_to_tbe:
@@ -600,6 +620,18 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         else:
             values, offsets, _ = _unwrap_kjt(features)

+        # Check if total embedding dimension is 0
+        total_D = sum(table.local_cols for table in self._config.embedding_tables)
+
+        if total_D == 0:
+            # For empty shards, return tensor with correct batch size but 0 embedding dimension
+            # Use tensor operations that are FX symbolic tracing compatible
+            # For offsets format, batch size is one less than offset size
+            # Use tensor slicing to create batch dimension
+            batch_tensor = offsets[:-1]  # Remove last element to get batch size
+            dummy = torch.zeros_like(batch_tensor, dtype=torch.float)
+            return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+
         if self._emb_module_registered:
             return self.emb_module(
                 indices=values,
diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index 3f8608899..3205be116 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -571,9 +571,9 @@ def test_cw(
     def test_uneven_cw(self, weight_dtype: torch.dtype, device_type: str) -> None:
         num_embeddings = 64
         emb_dim = 512
-        dim_1 = 63
+        dim_1 = 0
         dim_2 = 128
-        dim_3 = 65
+        dim_3 = 128
         dim_4 = 256
         local_size = 4
         world_size = 4