Generalize benchmark_module function #3223

Open · wants to merge 1 commit into main
239 changes: 101 additions & 138 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -31,14 +31,11 @@
get_origin,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)

import click

import torch
import yaml
from torch import multiprocessing as mp
@@ -60,11 +57,7 @@
from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
from torchrec.fx import symbolic_trace
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
EmbeddingCollection as QuantEmbeddingCollection,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.test_utils import get_free_port

logger: logging.Logger = logging.getLogger()
@@ -209,115 +202,56 @@ def _mem_percentile(
return torch.quantile(mem_data, percentile / 100.0, interpolation=interpolation)


class ECWrapper(torch.nn.Module):
class ModuleWrapper(torch.nn.Module):
"""
Wrapper Module for benchmarking EC Modules
A wrapper for nn.Module instances that lets them accept inputs of type
KeyedJaggedTensor or ModelInput and forwards those inputs to the underlying
module. The wrapper is needed for compatibility with FX tracing.

Args:
module: module to benchmark

Call Args:
input: KeyedJaggedTensor KJT input to module

Returns:
output: KT output from module
module: The torch.nn.Module to be wrapped.

Example:
e1_config = EmbeddingConfig(
name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
import torch
from torchrec.distributed.benchmark.benchmark_utils import ModuleWrapper
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Create a module that consumes KeyedJaggedTensor inputs
table = EmbeddingBagConfig(
name="t1", embedding_dim=4, num_embeddings=10, feature_names=["F1", "F2", "F3"]
)
module = EmbeddingBagCollection(tables=[table])
wrapped_module = ModuleWrapper(module)

# Create a KeyedJaggedTensor input
values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
weights = None
lengths = torch.tensor([2, 0, 1, 1, 3, 1])
offsets = torch.tensor([0, 2, 2, 3, 4, 7, 8])
keys = ["F1", "F2", "F3"]
kjt = KeyedJaggedTensor(
values=values,
weights=weights,
lengths=lengths,
offsets=offsets,
keys=keys,
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

features = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

ec.qconfig = torch.quantization.QConfig(
activation=torch.quantization.PlaceholderObserver.with_args(
dtype=torch.qint8
),
weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)

qec = QuantEmbeddingCollection.from_float(ecc)

wrapped_module = ECWrapper(qec)
quantized_embeddings = wrapped_module(features)
# Forward the input through the wrapped module
output = wrapped_module(kjt)
"""

def __init__(self, module: torch.nn.Module) -> None:
super().__init__()
self._module = module

def forward(self, input: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
"""
Args:
input (KeyedJaggedTensor): KJT of form [F X B X L].

Returns:
Dict[str, JaggedTensor]
def forward(self, input: Union[KeyedJaggedTensor, ModelInput]) -> Any: # pyre-ignore[3]
"""
return self._module.forward(input)


class EBCWrapper(torch.nn.Module):
"""
Wrapper Module for benchmarking Modules

Args:
module: module to benchmark

Call Args:
input: KeyedJaggedTensor KJT input to module

Returns:
output: KT output from module

Example:
table_0 = EmbeddingBagConfig(
name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
Forward pass of the wrapped module.

features = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

ebc.qconfig = torch.quantization.QConfig(
activation=torch.quantization.PlaceholderObserver.with_args(
dtype=torch.qint8
),
weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)

qebc = QuantEmbeddingBagCollection.from_float(ebc)

wrapped_module = EBCWrapper(qebc)
quantized_embeddings = wrapped_module(features)
"""

def __init__(self, module: torch.nn.Module) -> None:
super().__init__()
self._module = module

def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:
"""
Args:
input (KeyedJaggedTensor): KJT of form [F X B X L].
input: Input of type KeyedJaggedTensor or ModelInput to be forwarded to the underlying module.

Returns:
KeyedTensor
The output from the underlying module's forward pass.
"""
return self._module.forward(input)
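As a point of reference, here is a minimal sketch of how the unified wrapper can stand in for the removed EBCWrapper/ECWrapper on the FX-tracing path. The table config, feature names, and sizes are illustrative rather than taken from this PR, and the sketch assumes torchrec's EmbeddingBagCollection plus the torchrec.fx.symbolic_trace helper already imported elsewhere in this file.

import torch
from torchrec.distributed.benchmark.benchmark_utils import ModuleWrapper
from torchrec.fx import symbolic_trace
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Illustrative table; any module whose forward accepts a KJT or ModelInput works.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
        )
    ]
)
wrapped = ModuleWrapper(ebc)

# The wrapper exposes a single-argument forward, which is what the FX tracer sees.
traced = symbolic_trace(wrapped)

kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([0, 1, 2, 3]),
    lengths=torch.tensor([2, 2]),
)
output = traced(kjt)  # same KeyedTensor as calling ebc(kjt) directly

Because the wrapper pins the forward signature to a single input argument, the same tracing path applies whether the wrapped module consumes a KeyedJaggedTensor or a full ModelInput.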

@@ -369,11 +303,31 @@ def get_inputs(
batch_size: int,
world_size: int,
num_inputs: int,
num_float_features: int,
train: bool,
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
) -> List[List[KeyedJaggedTensor]]:
inputs_batch: List[List[KeyedJaggedTensor]] = []
only_kjt: bool = False,
) -> Union[List[List[ModelInput]], List[List[KeyedJaggedTensor]]]:
"""
Generate inputs for benchmarking.

Args:
tables: List of embedding table configurations
batch_size: Batch size for generated inputs
world_size: Number of ranks/processes
num_inputs: Number of input batches to generate
num_float_features: Number of float features
train: Whether inputs are for training
pooling_configs: Optional pooling factors for tables
variable_batch_embeddings: Whether to use variable batch size
only_kjt: If True, return KeyedJaggedTensor instead of ModelInput

Returns:
If only_kjt is False: List of lists of ModelInput objects
If only_kjt is True: List of lists of KeyedJaggedTensor objects
"""
inputs_batch = []

if variable_batch_embeddings and not train:
raise RuntimeError("Variable batch size is only supported in training mode")
@@ -383,14 +337,14 @@ def get_inputs(
_, model_input_by_rank = ModelInput.generate_variable_batch_input(
average_batch_size=batch_size,
world_size=world_size,
num_float_features=0,
num_float_features=num_float_features,
tables=tables,
)
else:
_, model_input_by_rank = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=0,
num_float_features=num_float_features,
tables=tables,
weighted_tables=[],
tables_pooling=pooling_configs,
@@ -399,21 +353,31 @@
)

if train:
sparse_features_by_rank = [
model_input.idlist_features
for model_input in model_input_by_rank
if isinstance(model_input.idlist_features, KeyedJaggedTensor)
]
inputs_batch.append(sparse_features_by_rank)
inputs_batch.append(
[
model_input
for model_input in model_input_by_rank
if isinstance(model_input.idlist_features, KeyedJaggedTensor)
or not only_kjt
]
)
else:
sparse_features = model_input_by_rank[0].idlist_features
assert isinstance(sparse_features, KeyedJaggedTensor)
inputs_batch.append([sparse_features])
assert (
isinstance(model_input_by_rank[0].idlist_features, KeyedJaggedTensor)
or not only_kjt
)
inputs_batch.append([model_input_by_rank[0]])

# If only_kjt is True, extract idlist_features from ModelInput objects
if only_kjt:
inputs_batch = [
[model_input.idlist_features for model_input in batch]
for batch in inputs_batch
]

# Transpose if train, as inputs_batch is currently in [B X R] format
inputs_by_rank = [
[sparse_features for sparse_features in sparse_features_rank]
for sparse_features_rank in zip(*inputs_batch)
list(model_inputs_rank) for model_inputs_rank in zip(*inputs_batch)
]

return inputs_by_rank
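For illustration, a hedged sketch of calling the generalized get_inputs in both modes; the table config and sizes below are placeholders, not values used by this PR.

from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
    )
]

# Legacy behaviour: lists of KeyedJaggedTensor per rank, as benchmark_module
# requests via only_kjt=True.
kjt_inputs = get_inputs(
    tables=tables,
    batch_size=32,
    world_size=2,
    num_inputs=4,
    num_float_features=0,
    train=True,
    only_kjt=True,
)

# Generalized behaviour: full ModelInput objects (sparse plus float features) for
# modules whose forward takes more than a KJT.
model_inputs = get_inputs(
    tables=tables,
    batch_size=32,
    world_size=2,
    num_inputs=4,
    num_float_features=16,
    train=True,
)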
@@ -630,7 +594,7 @@ def init_argparse_and_args() -> argparse.Namespace:
def transform_module(
module: torch.nn.Module,
device: torch.device,
inputs: List[KeyedJaggedTensor],
inputs: Union[List[ModelInput], List[KeyedJaggedTensor]],
sharder: ModuleSharder[T],
sharding_type: ShardingType,
compile_mode: CompileMode,
@@ -876,9 +840,13 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
def benchmark(
name: str,
model: torch.nn.Module,
warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
warmup_inputs: Union[
List[KeyedJaggedTensor], List[ModelInput], List[Dict[str, Any]]
],
bench_inputs: Union[
List[KeyedJaggedTensor], List[ModelInput], List[Dict[str, Any]]
],
prof_inputs: Union[List[KeyedJaggedTensor], List[ModelInput], List[Dict[str, Any]]],
world_size: int,
output_dir: str,
num_benchmarks: int,
@@ -994,9 +962,9 @@ def init_module_and_run_benchmark(
compile_mode: CompileMode,
world_size: int,
batch_size: int,
warmup_inputs: List[List[KeyedJaggedTensor]],
bench_inputs: List[List[KeyedJaggedTensor]],
prof_inputs: List[List[KeyedJaggedTensor]],
warmup_inputs: Union[List[List[ModelInput]], List[List[KeyedJaggedTensor]]],
bench_inputs: Union[List[List[ModelInput]], List[List[KeyedJaggedTensor]]],
prof_inputs: Union[List[List[ModelInput]], List[List[KeyedJaggedTensor]]],
tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]],
output_dir: str,
num_benchmarks: int,
@@ -1056,7 +1024,7 @@ def init_module_and_run_benchmark(
module = transform_module(
module=module,
device=device,
inputs=warmup_inputs_cuda,
inputs=warmup_inputs_cuda, # pyre-ignore[6]
sharder=sharder,
sharding_type=sharding_type,
compile_mode=compile_mode,
@@ -1075,9 +1043,9 @@
res = benchmark(
name,
module,
warmup_inputs_cuda,
bench_inputs_cuda,
prof_inputs_cuda,
warmup_inputs_cuda, # pyre-ignore[6]
bench_inputs_cuda, # pyre-ignore[6]
prof_inputs_cuda, # pyre-ignore[6]
world_size=world_size,
output_dir=output_dir,
num_benchmarks=num_benchmarks,
@@ -1167,6 +1135,7 @@ def benchmark_module(
warmup_iters: int = 20,
bench_iters: int = 500,
prof_iters: int = 20,
num_float_features: int = 0,
batch_size: int = 2048,
world_size: int = 2,
num_benchmarks: int = 5,
@@ -1177,6 +1146,7 @@
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
device_type: str = "cuda",
train: bool = True,
) -> List[BenchmarkResult]:
"""
Args:
@@ -1211,29 +1181,22 @@
assert (
num_benchmarks > 2
), "num_benchmarks needs to be greater than 2 for statistical analysis"
if isinstance(module, QuantEmbeddingBagCollection) or isinstance(
module, QuantEmbeddingCollection
):
train = False
else:
train = True

benchmark_results: List[BenchmarkResult] = []

if isinstance(tables[0], EmbeddingBagConfig):
wrapped_module = EBCWrapper(module)
else:
wrapped_module = ECWrapper(module)
wrapped_module = ModuleWrapper(module)

num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
inputs = get_inputs(
tables,
batch_size,
world_size,
num_inputs_to_gen,
num_float_features,
train,
pooling_configs,
variable_batch_embeddings,
only_kjt=True,
)

warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs]
@@ -1288,9 +1251,9 @@
compile_mode=compile_mode,
world_size=world_size,
batch_size=batch_size,
warmup_inputs=warmup_inputs,
bench_inputs=bench_inputs,
prof_inputs=prof_inputs,
warmup_inputs=warmup_inputs, # pyre-ignore[6]
bench_inputs=bench_inputs, # pyre-ignore[6]
prof_inputs=prof_inputs, # pyre-ignore[6]
tables=tables,
num_benchmarks=num_benchmarks,
output_dir=output_dir,
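One behavioural note: benchmark_module no longer infers eval mode from quantized module types, so callers that benchmark quantized collections now pass train=False themselves. Below is a hedged sketch of that flow, following the qconfig/from_float pattern from the docstrings removed above; table names and sizes are illustrative, and the trailing benchmark_module call is abbreviated.

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
)

tables = [
    EmbeddingBagConfig(
        name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
    )
]
ebc = EmbeddingBagCollection(tables=tables)
ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)
qebc = QuantEmbeddingBagCollection.from_float(ebc)

# Before this change, benchmark_module set train=False automatically on seeing a
# quantized collection; now the caller states it explicitly, e.g.
#   benchmark_module(module=qebc, tables=tables, ..., train=False)
# with the remaining benchmark_module arguments unchanged from the existing API.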