From a735ab0d4836fa8c835f2950b30972679e6a6243 Mon Sep 17 00:00:00 2001
From: James Dong
Date: Tue, 6 May 2025 19:27:23 -0700
Subject: [PATCH 1/3] Update KJT stride calculation logic to be based on
 inverse_indices for VBE KJTs. (#2949)

Summary:
Update the `_maybe_compute_stride_kjt` logic to calculate stride based on
`inverse_indices` for VBE KJTs.

Currently, the stride of a VBE KJT with `stride_per_key_per_rank` is
calculated as the max "stride per key". This differs from the batch size of
the EBC output KeyedTensor, which is derived from `inverse_indices`: for
example, with `stride_per_key_per_rank=[[2], [1]]` and `inverse_indices` of
shape `(2, 3)`, the old logic returns 2 while the KeyedTensor batch size is 3.
This mismatch causes issues in IR module serialization (see the debug doc).

Differential Revision: D74273083
---
 torchrec/sparse/jagged_tensor.py                  |  5 +++++
 torchrec/sparse/tests/test_keyed_jagged_tensor.py | 12 ++++++++++++
 2 files changed, 17 insertions(+)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 2bbe09149..dacec0407 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1097,11 +1097,15 @@ def _maybe_compute_stride_kjt(
     lengths: Optional[torch.Tensor],
     offsets: Optional[torch.Tensor],
     stride_per_key_per_rank: Optional[List[List[int]]],
+    inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
 ) -> int:
     if stride is None:
         if len(keys) == 0:
             stride = 0
         elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
+            # For VBE KJT, use inverse_indices for the batch size of the EBC
+            # output KeyedTensor.
+            if inverse_indices is not None and inverse_indices[1].numel() > 0:
+                return inverse_indices[1].shape[-1]
             stride = max([sum(s) for s in stride_per_key_per_rank])
         elif offsets is not None and offsets.numel() > 0:
             stride = (offsets.numel() - 1) // len(keys)
@@ -2165,6 +2169,7 @@ def stride(self) -> int:
             self._lengths,
             self._offsets,
             self._stride_per_key_per_rank,
+            self._inverse_indices,
         )
         self._stride = stride
         return stride
diff --git a/torchrec/sparse/tests/test_keyed_jagged_tensor.py b/torchrec/sparse/tests/test_keyed_jagged_tensor.py
index 1636a06bd..ed8a9eac9 100644
--- a/torchrec/sparse/tests/test_keyed_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_keyed_jagged_tensor.py
@@ -1017,6 +1017,18 @@ def test_meta_device_compatibility(self) -> None:
             lengths=torch.tensor([], device=torch.device("meta")),
         )

+    def test_vbe_kjt_stride(self) -> None:
+        inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
+        kjt = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank=[[2], [1]],
+            inverse_indices=(["f1", "f2"], inverse_indices),
+        )
+
+        self.assertEqual(kjt.stride(), inverse_indices.shape[1])
+

 class TestKeyedJaggedTensorScripting(unittest.TestCase):
     def test_scriptable_forward(self) -> None:

From 905c73fca28fe420e5e20f3da7891f199432044f Mon Sep 17 00:00:00 2001
From: James Dong
Date: Tue, 6 May 2025 19:27:23 -0700
Subject: [PATCH 2/3] Create a new tensor arg for stride_per_key_per_rank to
 facilitate torch.export (#2950)

Summary:

# Context
* Currently the torchrec IR serializer can't handle the variable batch (VBE) KJT use case.
* To support VBE KJTs, the `stride_per_key_per_rank` field needs to be flattened as a variable in the pytree flatten spec, so that a VBE KJT can be unflattened correctly by `torch.export`.
* `stride_per_key_per_rank` is currently a `List`. To flatten the `stride_per_key_per_rank` info as a variable, we have to add a new tensor field for it (see the sketch below).
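A minimal sketch of the list/tensor round trip the new field enables; `to_tensor` and `from_tensor` are illustrative helper names, not part of the torchrec API:

```python
import torch
from typing import List, Optional

def to_tensor(stride_per_key_per_rank: Optional[List[List[int]]]) -> torch.Tensor:
    # An empty tensor is the sentinel for "no VBE info" (None), mirroring
    # the torch.empty(0) default used in the diff below.
    if stride_per_key_per_rank is None:
        return torch.empty(0)
    return torch.tensor(stride_per_key_per_rank, dtype=torch.int64)

def from_tensor(t: torch.Tensor) -> Optional[List[List[int]]]:
    # Recover the legacy List[List[int]] form via tolist(), as the new
    # stride_per_key_per_rank() / _stride_per_key_per_rank_optional do.
    return t.tolist() if t.numel() > 0 else None

assert from_tensor(to_tensor([[2], [1]])) == [[2], [1]]
assert from_tensor(to_tensor(None)) is None
```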
# Ref
Differential Revision: D74207283
---
 torchrec/sparse/jagged_tensor.py                  | 52 ++++++++++++++++---
 torchrec/sparse/tests/test_keyed_jagged_tensor.py | 22 ++++++--
 2 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index dacec0407..17d617eeb 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1779,6 +1779,7 @@ def __init__(
         index_per_key: Optional[Dict[str, int]] = None,
         jt_dict: Optional[Dict[str, JaggedTensor]] = None,
         inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
+        stride_per_key_per_rank_tensor: Optional[torch.Tensor] = None,
     ) -> None:
         """
         This is the constructor for KeyedJaggedTensor; it is jit.scriptable and PT2 compatible.
@@ -1795,6 +1796,11 @@ def __init__(
         self._stride_per_key_per_rank: Optional[List[List[int]]] = (
             stride_per_key_per_rank
         )
+
+        self._stride_per_key_per_rank_tensor: torch.Tensor = torch.empty(0)
+        if stride_per_key_per_rank_tensor is not None:
+            self._stride_per_key_per_rank_tensor = stride_per_key_per_rank_tensor
+
         self._stride_per_key: Optional[List[int]] = stride_per_key
         self._length_per_key: Optional[List[int]] = length_per_key
         self._offset_per_key: Optional[List[int]] = offset_per_key
@@ -2184,7 +2190,7 @@ def stride_per_key(self) -> List[int]:
         """
         stride_per_key = _maybe_compute_stride_per_key(
             self._stride_per_key,
-            self._stride_per_key_per_rank,
+            self._stride_per_key_per_rank_optional,
             self.stride(),
             self._keys,
         )
@@ -2199,7 +2205,27 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
             List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
         """
         stride_per_key_per_rank = self._stride_per_key_per_rank
-        return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
+
+        if stride_per_key_per_rank is not None:
+            return stride_per_key_per_rank
+
+        if self._stride_per_key_per_rank_tensor.numel() > 0:
+            return self._stride_per_key_per_rank_tensor.tolist()
+
+        return []
+
+    @property
+    def _stride_per_key_per_rank_optional(self) -> Optional[List[List[int]]]:
+        if self._stride_per_key_per_rank is not None:
+            return self._stride_per_key_per_rank
+
+        if self._stride_per_key_per_rank_tensor.numel() > 0:
+            stride_per_key_per_rank: List[List[int]] = (
+                self._stride_per_key_per_rank_tensor.tolist()
+            )
+            return stride_per_key_per_rank
+
+        return None

     def variable_stride_per_key(self) -> bool:
         """
@@ -2210,7 +2236,7 @@ def variable_stride_per_key(self) -> bool:
         """
         if self._variable_stride_per_key is not None:
             return self._variable_stride_per_key
-        return self._stride_per_key_per_rank is not None
+        return self._stride_per_key_per_rank_optional is not None

     def inverse_indices(self) -> Tuple[List[str], torch.Tensor]:
         """
@@ -2375,6 +2401,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         index_per_key=self._index_per_key,
                         jt_dict=self._jt_dict,
                         inverse_indices=None,
+                        stride_per_key_per_rank_tensor=None,
                     )
                 )
             elif segment == 0:
@@ -2411,6 +2438,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         index_per_key=None,
                         jt_dict=None,
                         inverse_indices=None,
+                        stride_per_key_per_rank_tensor=None,
                     )
                 )
             else:
@@ -2457,6 +2485,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         index_per_key=None,
                         jt_dict=None,
                         inverse_indices=None,
+                        stride_per_key_per_rank_tensor=None,
                     )
                 )
             else:
@@ -2493,6 +2522,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         index_per_key=None,
                         jt_dict=None,
                         inverse_indices=None,
+                        stride_per_key_per_rank_tensor=None,
                     )
                 )
             start = end
@@ -2599,12 +2629,15 @@ def permute(
             index_per_key=None,
             jt_dict=None,
             inverse_indices=None,
+            stride_per_key_per_rank_tensor=None,
         )
         return kjt

     def flatten_lengths(self) -> "KeyedJaggedTensor":
         stride_per_key_per_rank = (
-            self._stride_per_key_per_rank if self.variable_stride_per_key() else None
+            self._stride_per_key_per_rank_optional
+            if self.variable_stride_per_key()
+            else None
         )
         return KeyedJaggedTensor(
             keys=self._keys,
@@ -2621,6 +2654,7 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
             index_per_key=None,
             jt_dict=None,
             inverse_indices=None,
+            stride_per_key_per_rank_tensor=None,
         )

     def __getitem__(self, key: str) -> JaggedTensor:
@@ -2760,7 +2794,9 @@ def to(
         lengths = self._lengths
         offsets = self._offsets
         stride_per_key_per_rank = (
-            self._stride_per_key_per_rank if self.variable_stride_per_key() else None
+            self._stride_per_key_per_rank_optional
+            if self.variable_stride_per_key()
+            else None
         )
         length_per_key = self._length_per_key
         lengths_offset_per_key = self._lengths_offset_per_key
@@ -2805,6 +2841,7 @@ def to(
             index_per_key=index_per_key,
             jt_dict=jt_dict,
             inverse_indices=inverse_indices,
+            stride_per_key_per_rank_tensor=None,
        )

     def __str__(self) -> str:
@@ -2836,7 +2873,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
         lengths = self._lengths
         offsets = self._offsets
         stride_per_key_per_rank = (
-            self._stride_per_key_per_rank if self.variable_stride_per_key() else None
+            self._stride_per_key_per_rank_optional
+            if self.variable_stride_per_key()
+            else None
         )
         inverse_indices = self._inverse_indices
         if inverse_indices is not None:
@@ -2857,6 +2896,7 @@ def pin_memory(self) -> "KeyedJaggedTensor":
             index_per_key=self._index_per_key,
             jt_dict=None,
             inverse_indices=inverse_indices,
+            stride_per_key_per_rank_tensor=None,
         )

     def dist_labels(self) -> List[str]:
diff --git a/torchrec/sparse/tests/test_keyed_jagged_tensor.py b/torchrec/sparse/tests/test_keyed_jagged_tensor.py
index ed8a9eac9..581e27208 100644
--- a/torchrec/sparse/tests/test_keyed_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_keyed_jagged_tensor.py
@@ -1018,16 +1018,32 @@ def test_meta_device_compatibility(self) -> None:
         )

     def test_vbe_kjt_stride(self) -> None:
+        stride_per_key_per_rank = [[2], [1]]
         inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
-        kjt = KeyedJaggedTensor(
+        kjt_1 = KeyedJaggedTensor(
             keys=["f1", "f2", "f3"],
             values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
             lengths=torch.tensor([3, 3, 2]),
-            stride_per_key_per_rank=[[2], [1]],
+            stride_per_key_per_rank=stride_per_key_per_rank,
+            inverse_indices=(["f1", "f2"], inverse_indices),
+        )
+        kjt_2 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank_tensor=torch.tensor(stride_per_key_per_rank),
             inverse_indices=(["f1", "f2"], inverse_indices),
         )

-        self.assertEqual(kjt.stride(), inverse_indices.shape[1])
+        self.assertEqual(kjt_1.stride(), inverse_indices.shape[1])
+        self.assertEqual(kjt_1.stride_per_key_per_rank(), stride_per_key_per_rank)
+        self.assertEqual(
+            kjt_1._stride_per_key_per_rank_optional, stride_per_key_per_rank
+        )
+        self.assertEqual(kjt_2.stride_per_key_per_rank(), stride_per_key_per_rank)
+        self.assertEqual(
+            kjt_2._stride_per_key_per_rank_optional, stride_per_key_per_rank
+        )

 class TestKeyedJaggedTensorScripting(unittest.TestCase):
     def test_scriptable_forward(self) -> None:

From 20069fd1eda28c2f7cddcfdff892aee0ca22b395 Mon Sep 17 00:00:00 2001
From: James Dong
Date: Tue, 6 May 2025 19:27:23 -0700
Subject: [PATCH 3/3] Add missing fields to KJT's PyTree
 flatten/unflatten logic for VBE KJT

Summary:

# Context
* Currently the torchrec IR serializer does not support exporting variable batch KJTs: the `stride_per_key_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs, but they are not included in the KJT's PyTree flatten/unflatten functions.
* This diff updates KJT's PyTree flatten/unflatten functions to include `stride_per_key_per_rank` and `inverse_indices`.

# Ref
Differential Revision: D74295924
---
 torchrec/ir/tests/test_serializer.py  | 80 +++++++++++++++------------
 torchrec/modules/embedding_modules.py |  2 +-
 torchrec/sparse/jagged_tensor.py      | 15 ++++-
 3 files changed, 59 insertions(+), 38 deletions(-)

diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 31af19ec8..3e108675d 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -206,8 +206,14 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
+        config3 = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
         ebc = EmbeddingBagCollection(
-            tables=[config1, config2],
+            tables=[config1, config2, config3],
             is_weighted=False,
         )
@@ -292,24 +298,37 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

-    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
-        id_list_features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
-            lengths=torch.tensor([3, 3, 2]),
-            stride_per_key_per_rank=[[2], [1]],
-            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        kjt_1 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank_tensor=torch.tensor([[3], [2], [1]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
+            ),
+        )
+        kjt_2 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank_tensor=torch.tensor([[1], [2], [3]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
+            ),
         )

-        eager_out = model(id_list_features)
+        eager_out = model(kjt_1)
+        eager_out_2 = model(kjt_2)

         # Serialize EBC
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
         ep = torch.export.export(
             model,
-            (id_list_features,),
+            (kjt_1,),
             {},
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
         )

         # Run forward on ExportedProgram
-        ep_output = ep.module()(id_list_features)
+        ep_output = ep.module()(kjt_1)
+        ep_output_2 = ep.module()(kjt_2)

+        self.assertEqual(len(ep_output), len(kjt_1.keys()))
+        self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape, tensor.shape)
+            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
+        for i, tensor in enumerate(ep_output_2):
+            self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])

         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

         # check EBC config
-        for i in range(5):
+        for i in range(1):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
             )
@@ -342,36 +366,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
             self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
             self.assertEqual(deserialized.feature_names, orginal.feature_names)

-        # check FPEBC config
-        for i in range(2):
-            fpebc_name = f"fpebc{i + 1}"
-            assert isinstance(
-                getattr(deserialized_model, fpebc_name),
-                FeatureProcessedEmbeddingBagCollection,
-            )
-
-            for deserialized, orginal in zip(
-                getattr(
-                    deserialized_model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-                getattr(
-                    model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-            ):
-                self.assertEqual(deserialized.name, orginal.name)
-                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
-                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
-                self.assertEqual(deserialized.feature_names, orginal.feature_names)
-
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
-        deserialized_out = deserialized_model(id_list_features)
+        deserialized_out = deserialized_model(kjt_1)

         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

+        deserialized_out_2 = deserialized_model(kjt_2)
+
+        self.assertEqual(len(deserialized_out_2), len(eager_out_2))
+        for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
+            self.assertEqual(deserialized.shape, orginal.shape)
+            self.assertTrue(torch.allclose(deserialized, orginal))
+
     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
         feature1 = KeyedJaggedTensor.from_offsets_sync(
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index d110fd57f..77c34fe7e 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -27,7 +27,7 @@ def reorder_inverse_indices(
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
     feature_names: List[str],
 ) -> torch.Tensor:
-    if inverse_indices is None:
+    if inverse_indices is None or inverse_indices[1].numel() == 0:
         return torch.empty(0)
     index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
     index = torch.tensor(
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 17d617eeb..3c8095cfc 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1760,6 +1760,8 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
         "_weights",
         "_lengths",
         "_offsets",
+        "_stride_per_key_per_rank_tensor",
+        "_inverse_indices_tensor",
     ]

     def __init__(
@@ -1810,6 +1812,9 @@ def __init__(
         self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
             inverse_indices
         )
+        self._inverse_indices_tensor: torch.Tensor = torch.empty(0)
+        if inverse_indices is not None:
+            self._inverse_indices_tensor = inverse_indices[1]

         # legacy attribute, for backward compatibility
         self._variable_stride_per_key: Optional[bool] = None
@@ -3092,9 +3097,15 @@ def _kjt_flatten_with_keys(

 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: List[str],
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    return KeyedJaggedTensor(
+        context,
+        *values[:-2],
+        stride_per_key_per_rank_tensor=values[-2],
+        inverse_indices=(context, values[-1]) if values[-1] is not None else None,
+    )

 def _kjt_flatten_spec(
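With all three patches applied, a VBE KJT should round-trip through PyTree flattening. A minimal usage sketch, reusing the values from the unit test above and assuming KJT's pytree registration picks up the updated `_fields`:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank_tensor=torch.tensor([[2], [1]]),
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

# _fields now ends with _stride_per_key_per_rank_tensor and
# _inverse_indices_tensor, so both extra leaves survive the round trip
# and _kjt_unflatten can rebuild the VBE KJT from them.
leaves, spec = tree_flatten(kjt)
rebuilt = tree_unflatten(leaves, spec)

assert rebuilt.stride() == kjt.stride()  # 3, from inverse_indices
assert rebuilt.stride_per_key_per_rank() == [[2], [1]]
```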