diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index f03d789c7..93c1fd295 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -31,14 +31,11 @@ get_origin, List, Optional, - Set, Tuple, TypeVar, Union, ) -import click - import torch import yaml from torch import multiprocessing as mp @@ -60,11 +57,7 @@ from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv from torchrec.fx import symbolic_trace from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig -from torchrec.quant.embedding_modules import ( - EmbeddingBagCollection as QuantEmbeddingBagCollection, - EmbeddingCollection as QuantEmbeddingCollection, -) -from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.test_utils import get_free_port logger: logging.Logger = logging.getLogger() @@ -209,115 +202,56 @@ def _mem_percentile( return torch.quantile(mem_data, percentile / 100.0, interpolation=interpolation) -class ECWrapper(torch.nn.Module): +class ModuleWrapper(torch.nn.Module): """ - Wrapper Module for benchmarking EC Modules + A wrapper for nn.modules that allows them to accept inputs + of type KeyedJaggedTensor or ModelInput and forwards them to the + underlying module. This wrapper is necessary to provide compatibility + with FX tracing. Args: - module: module to benchmark - - Call Args: - input: KeyedJaggedTensor KJT input to module - - Returns: - output: KT output from module + module: The torch.nn.Module to be wrapped. Example: - e1_config = EmbeddingConfig( - name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] - ) - e2_config = EmbeddingConfig( - name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + import torch + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + from torchrec.distributed.benchmark.benchmark_utils import ModuleWrapper + + # Create a simple module + module = torch.nn.Linear(10, 5) + wrapped_module = ModuleWrapper(module) + + # Create a KeyedJaggedTensor input + values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + weights = None + lengths = torch.tensor([2, 0, 1, 1, 3, 1]) + offsets = torch.tensor([0, 2, 2, 3, 4, 7, 8]) + keys = ["F1", "F2", "F3"] + kjt = KeyedJaggedTensor( + values=values, + weights=weights, + lengths=lengths, + offsets=offsets, + keys=keys, ) - ec = EmbeddingCollection(tables=[e1_config, e2_config]) - - features = KeyedJaggedTensor( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - ) - - ec.qconfig = torch.quantization.QConfig( - activation=torch.quantization.PlaceholderObserver.with_args( - dtype=torch.qint8 - ), - weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), - ) - - qec = QuantEmbeddingCollection.from_float(ecc) - - wrapped_module = ECWrapper(qec) - quantized_embeddings = wrapped_module(features) + # Forward the input through the wrapped module + output = wrapped_module(kjt) """ def __init__(self, module: torch.nn.Module) -> None: super().__init__() self._module = module - def forward(self, input: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: - """ - Args: - input (KeyedJaggedTensor): KJT of form [F X B X L]. - - Returns: - Dict[str, JaggedTensor] + def forward(self, input: Union[KeyedJaggedTensor, ModelInput]) -> Any: # pyre-ignore[3] """ - return self._module.forward(input) - - -class EBCWrapper(torch.nn.Module): - """ - Wrapper Module for benchmarking Modules - - Args: - module: module to benchmark - - Call Args: - input: KeyedJaggedTensor KJT input to module - - Returns: - output: KT output from module - - Example: - table_0 = EmbeddingBagConfig( - name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] - ) - table_1 = EmbeddingBagConfig( - name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] - ) - ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + Forward pass of the wrapped module. - features = KeyedJaggedTensor( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - ) - - ebc.qconfig = torch.quantization.QConfig( - activation=torch.quantization.PlaceholderObserver.with_args( - dtype=torch.qint8 - ), - weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), - ) - - qebc = QuantEmbeddingBagCollection.from_float(ebc) - - wrapped_module = EBCWrapper(qebc) - quantized_embeddings = wrapped_module(features) - """ - - def __init__(self, module: torch.nn.Module) -> None: - super().__init__() - self._module = module - - def forward(self, input: KeyedJaggedTensor) -> KeyedTensor: - """ Args: - input (KeyedJaggedTensor): KJT of form [F X B X L]. + input: Input of type KeyedJaggedTensor or ModelInput to be forwarded to the underlying module. Returns: - KeyedTensor + The output from the underlying module's forward pass. """ return self._module.forward(input) @@ -369,11 +303,31 @@ def get_inputs( batch_size: int, world_size: int, num_inputs: int, + num_float_features: int, train: bool, pooling_configs: Optional[List[int]] = None, variable_batch_embeddings: bool = False, -) -> List[List[KeyedJaggedTensor]]: - inputs_batch: List[List[KeyedJaggedTensor]] = [] + only_kjt: bool = False, +) -> Union[List[List[ModelInput]], List[List[KeyedJaggedTensor]]]: + """ + Generate inputs for benchmarking. + + Args: + tables: List of embedding tables configurations + batch_size: Batch size for generated inputs + world_size: Number of ranks/processes + num_inputs: Number of input batches to generate + num_float_features: Number of float features + train: Whether inputs are for training + pooling_configs: Optional pooling factors for tables + variable_batch_embeddings: Whether to use variable batch size + only_kjt: If True, return KeyedJaggedTensor instead of ModelInput + + Returns: + If only_kjt is False: List of lists of ModelInput objects + If only_kjt is True: List of lists of KeyedJaggedTensor objects + """ + inputs_batch = [] if variable_batch_embeddings and not train: raise RuntimeError("Variable batch size is only supported in training mode") @@ -383,14 +337,14 @@ def get_inputs( _, model_input_by_rank = ModelInput.generate_variable_batch_input( average_batch_size=batch_size, world_size=world_size, - num_float_features=0, + num_float_features=num_float_features, tables=tables, ) else: _, model_input_by_rank = ModelInput.generate( batch_size=batch_size, world_size=world_size, - num_float_features=0, + num_float_features=num_float_features, tables=tables, weighted_tables=[], tables_pooling=pooling_configs, @@ -399,21 +353,31 @@ def get_inputs( ) if train: - sparse_features_by_rank = [ - model_input.idlist_features - for model_input in model_input_by_rank - if isinstance(model_input.idlist_features, KeyedJaggedTensor) - ] - inputs_batch.append(sparse_features_by_rank) + inputs_batch.append( + [ + model_input + for model_input in model_input_by_rank + if isinstance(model_input.idlist_features, KeyedJaggedTensor) + or not only_kjt + ] + ) else: - sparse_features = model_input_by_rank[0].idlist_features - assert isinstance(sparse_features, KeyedJaggedTensor) - inputs_batch.append([sparse_features]) + assert ( + isinstance(model_input_by_rank[0].idlist_features, KeyedJaggedTensor) + or not only_kjt + ) + inputs_batch.append([model_input_by_rank[0]]) + + # If only_kjt is True, extract idlist_features from ModelInput objects + if only_kjt: + inputs_batch = [ + [model_input.idlist_features for model_input in batch] + for batch in inputs_batch + ] # Transpose if train, as inputs_by_rank is currently in [B X R] format inputs_by_rank = [ - [sparse_features for sparse_features in sparse_features_rank] - for sparse_features_rank in zip(*inputs_batch) + list(model_inputs_rank) for model_inputs_rank in zip(*inputs_batch) ] return inputs_by_rank @@ -630,7 +594,7 @@ def init_argparse_and_args() -> argparse.Namespace: def transform_module( module: torch.nn.Module, device: torch.device, - inputs: List[KeyedJaggedTensor], + inputs: Union[List[ModelInput], List[KeyedJaggedTensor]], sharder: ModuleSharder[T], sharding_type: ShardingType, compile_mode: CompileMode, @@ -876,9 +840,13 @@ def _trace_handler(prof: torch.profiler.profile) -> None: def benchmark( name: str, model: torch.nn.Module, - warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], - bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], - prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]], + warmup_inputs: Union[ + List[KeyedJaggedTensor], List[ModelInput], List[Dict[str, Any]] + ], + bench_inputs: Union[ + List[KeyedJaggedTensor], List[ModelInput], List[Dict[str, Any]] + ], + prof_inputs: Union[List[KeyedJaggedTensor], List[ModelInput], List[Dict[str, Any]]], world_size: int, output_dir: str, num_benchmarks: int, @@ -994,9 +962,9 @@ def init_module_and_run_benchmark( compile_mode: CompileMode, world_size: int, batch_size: int, - warmup_inputs: List[List[KeyedJaggedTensor]], - bench_inputs: List[List[KeyedJaggedTensor]], - prof_inputs: List[List[KeyedJaggedTensor]], + warmup_inputs: Union[List[List[ModelInput]], List[List[KeyedJaggedTensor]]], + bench_inputs: Union[List[List[ModelInput]], List[List[KeyedJaggedTensor]]], + prof_inputs: Union[List[List[ModelInput]], List[List[KeyedJaggedTensor]]], tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], output_dir: str, num_benchmarks: int, @@ -1056,7 +1024,7 @@ def init_module_and_run_benchmark( module = transform_module( module=module, device=device, - inputs=warmup_inputs_cuda, + inputs=warmup_inputs_cuda, # pyre-ignore[6] sharder=sharder, sharding_type=sharding_type, compile_mode=compile_mode, @@ -1075,9 +1043,9 @@ def init_module_and_run_benchmark( res = benchmark( name, module, - warmup_inputs_cuda, - bench_inputs_cuda, - prof_inputs_cuda, + warmup_inputs_cuda, # pyre-ignore[6] + bench_inputs_cuda, # pyre-ignore[6] + prof_inputs_cuda, # pyre-ignore[6] world_size=world_size, output_dir=output_dir, num_benchmarks=num_benchmarks, @@ -1167,6 +1135,7 @@ def benchmark_module( warmup_iters: int = 20, bench_iters: int = 500, prof_iters: int = 20, + num_float_features: int = 0, batch_size: int = 2048, world_size: int = 2, num_benchmarks: int = 5, @@ -1177,6 +1146,7 @@ def benchmark_module( pooling_configs: Optional[List[int]] = None, variable_batch_embeddings: bool = False, device_type: str = "cuda", + train: bool = True, ) -> List[BenchmarkResult]: """ Args: @@ -1211,19 +1181,10 @@ def benchmark_module( assert ( num_benchmarks > 2 ), "num_benchmarks needs to be greater than 2 for statistical analysis" - if isinstance(module, QuantEmbeddingBagCollection) or isinstance( - module, QuantEmbeddingCollection - ): - train = False - else: - train = True benchmark_results: List[BenchmarkResult] = [] - if isinstance(tables[0], EmbeddingBagConfig): - wrapped_module = EBCWrapper(module) - else: - wrapped_module = ECWrapper(module) + wrapped_module = ModuleWrapper(module) num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters inputs = get_inputs( @@ -1231,9 +1192,11 @@ def benchmark_module( batch_size, world_size, num_inputs_to_gen, + num_float_features, train, pooling_configs, variable_batch_embeddings, + only_kjt=True, ) warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs] @@ -1288,9 +1251,9 @@ def benchmark_module( compile_mode=compile_mode, world_size=world_size, batch_size=batch_size, - warmup_inputs=warmup_inputs, - bench_inputs=bench_inputs, - prof_inputs=prof_inputs, + warmup_inputs=warmup_inputs, # pyre-ignore[6] + bench_inputs=bench_inputs, # pyre-ignore[6] + prof_inputs=prof_inputs, # pyre-ignore[6] tables=tables, num_benchmarks=num_benchmarks, output_dir=output_dir,