Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,53 @@ def to_dense(self) -> List[torch.Tensor]:
tensor_list.append(self.values()[offset:next_offset])
return tensor_list

def to_dense_stacked(self) -> torch.Tensor:
    """
    Convert this JaggedTensor into one dense tensor, with shortcuts for
    uniform sequence lengths.

    Three cases are handled:

    1. Every length is 1 (or the batch is empty): ``values`` is returned as a
       view of shape ``(batch_size, 1)`` for 1-D values, or
       ``(batch_size, values.size(-1))`` otherwise. No padding kernel runs.
    2. Every length equals the same ``L != 1``: values are densified with
       ``torch.ops.fbgemm.jagged_to_padded_dense`` using ``L`` as the exact
       max length and ``0.0`` as the padding value.
    3. Anything else: the result of ``self.to_padded_dense()``.

    For 1-D values with uniform lengths the result is the row-wise stacking
    of the sequences, i.e. shape ``(batch_size, L)``. For variable lengths
    the result is padded, so it is NOT a plain stacking of ``to_dense()``.

    Host synchronization: using ``torch.all(...)`` as a Python condition
    reads the result on the host, which for device tensors requires a
    device-to-host sync. To keep that cost to at most one read, the only
    possible uniform length is derived from tensor shapes
    (``values.size(0) // batch_size``) instead of from ``lengths[0].item()``,
    and the check is skipped entirely when the shapes already rule out
    uniform lengths or the batch is empty.

    This relies on the JaggedTensor invariant
    ``sum(lengths) == values.size(0)``; inputs violating it never match a
    fast path and fall through to ``to_padded_dense()``.

    Returns:
        torch.Tensor: dense representation of the jagged values.
    """
    lengths = self.lengths()
    values = self.values()

    # Shape metadata lives on the host, so these reads never synchronize.
    batch_size = lengths.size(0)
    num_values = values.size(0)

    if num_values == batch_size:
        # One value per sequence on average: lengths are uniform only if
        # they are all exactly 1. An empty batch trivially qualifies and
        # needs no device-side check.
        if batch_size == 0 or bool(torch.all(lengths == 1)):
            if values.dim() == 1:
                return values.view(batch_size, 1)
            return values.view(batch_size, values.size(-1))
    elif batch_size > 0 and num_values % batch_size == 0:
        # The only length every sequence could share is the exact quotient.
        seq_length = num_values // batch_size
        if bool(torch.all(lengths == seq_length)):
            # Exact max length is known, so the output needs no trimming.
            return torch.ops.fbgemm.jagged_to_padded_dense(
                values, [self.offsets()], [seq_length], 0.0
            )

    # Truly variable lengths: defer to the general padded conversion.
    return self.to_padded_dense()

def to_dense_weights(self) -> Optional[List[torch.Tensor]]:
"""
Constructs a dense-representation of the JT's weights.
Expand Down
94 changes: 94 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,100 @@ def test_to_padded_dense(self) -> None:
expected_t2 = torch.tensor(t2_value).type(torch.int64)
self.assertTrue(torch.equal(t2, expected_t2))

def test_to_dense_stacked(self) -> None:
    """Check to_dense_stacked on length-1, uniform, variable and empty inputs."""
    # Case 1a: every sequence has length 1, 1-D values -> (batch, 1).
    values = torch.tensor([1.0, 2.0, 3.0, 4.0])
    lengths = torch.tensor([1, 1, 1, 1], dtype=torch.int32)
    jt = JaggedTensor(values=values, lengths=lengths)

    dense_stacked = jt.to_dense_stacked()
    expected = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    self.assertTrue(torch.equal(dense_stacked, expected))

    # Case 1b: every sequence has length 1, 2-D values keep their shape.
    values_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    lengths = torch.tensor([1, 1, 1], dtype=torch.int32)
    jt_2d = JaggedTensor(values=values_2d, lengths=lengths)

    dense_stacked_2d = jt_2d.to_dense_stacked()
    expected_2d = values_2d.view(3, 2)
    self.assertTrue(torch.equal(dense_stacked_2d, expected_2d))

    # Case 2: uniform length other than 1 must agree with to_padded_dense.
    values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    lengths = torch.tensor([2, 2, 2], dtype=torch.int32)
    jt = JaggedTensor(values=values, lengths=lengths)

    dense_stacked = jt.to_dense_stacked()
    expected_padded = jt.to_padded_dense()
    self.assertTrue(torch.equal(dense_stacked, expected_padded))

    # Case 3: variable lengths take the to_padded_dense fallback.
    values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    lengths = torch.tensor([2, 0, 3, 1, 2], dtype=torch.int32)
    jt = JaggedTensor(values=values, lengths=lengths)

    dense_stacked = jt.to_dense_stacked()
    expected_padded = jt.to_padded_dense()
    self.assertTrue(torch.equal(dense_stacked, expected_padded))

    # Case 4: empty JaggedTensor yields an empty (0, 1) tensor.
    empty_values = torch.tensor([], dtype=torch.float32)
    empty_lengths = torch.tensor([], dtype=torch.int32)
    jt_empty = JaggedTensor(values=empty_values, lengths=empty_lengths)

    dense_stacked_empty = jt_empty.to_dense_stacked()
    self.assertEqual(dense_stacked_empty.numel(), 0)
    self.assertEqual(tuple(dense_stacked_empty.shape), (0, 1))

    # Case 5: a single sequence holding several elements.
    values = torch.tensor([10.0, 20.0, 30.0])
    lengths = torch.tensor([3], dtype=torch.int32)
    jt_single = JaggedTensor(values=values, lengths=lengths)

    dense_stacked_single = jt_single.to_dense_stacked()
    expected_single = torch.tensor([[10.0, 20.0, 30.0]])
    self.assertTrue(torch.equal(dense_stacked_single, expected_single))

    # Case 6: mix of empty and non-empty sequences (variable length).
    values = torch.tensor([1.0, 2.0, 3.0])
    lengths = torch.tensor([2, 0, 1], dtype=torch.int32)
    jt_mixed = JaggedTensor(values=values, lengths=lengths)

    dense_stacked_mixed = jt_mixed.to_dense_stacked()
    expected_mixed = jt_mixed.to_padded_dense()
    self.assertTrue(torch.equal(dense_stacked_mixed, expected_mixed))

    # Case 7: larger random input must still match to_padded_dense.
    # The generator is seeded so the test is reproducible, and values are
    # sized from the lengths; patching the last length to hit a fixed total
    # could drive it to zero or below and build an invalid JaggedTensor.
    generator = torch.Generator().manual_seed(0)
    lengths = torch.randint(
        1, 10, (20,), dtype=torch.int32, generator=generator
    )
    values = torch.randn(int(lengths.sum().item()), generator=generator)

    jt_large = JaggedTensor(values=values, lengths=lengths)

    dense_stacked_large = jt_large.to_dense_stacked()
    expected_padded_large = jt_large.to_padded_dense()
    self.assertTrue(torch.equal(dense_stacked_large, expected_padded_large))

    # Case 8: integer values keep their dtype through the uniform path.
    values_int = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.int64)
    lengths_int = torch.tensor([2, 2, 2], dtype=torch.int32)
    jt_int = JaggedTensor(values=values_int, lengths=lengths_int)

    dense_stacked_int = jt_int.to_dense_stacked()
    self.assertEqual(dense_stacked_int.dtype, torch.int64)
    expected_int = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int64)
    self.assertTrue(torch.equal(dense_stacked_int, expected_int))

def test_to_padded_dense_weights(self) -> None:
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).type(
torch.float64
Expand Down
Loading