diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index e542f1d417..4549fc3357 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -655,6 +655,38 @@ def aten_ops_dynamic_block_quantize_op( ) +def attention_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: + # Currently, `attn_mask` is not supported + return args_bounds_check(node.args, 3) is None + + +@dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, + capability_validator=attention_validator, + supports_dynamic_shapes=True, +) +def tensorrt_scaled_dot_product_attention( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.attention.scaled_dot_product_attention( + ctx, + target, + SourceIR.TORCHTRT_LOWERED, + name, + args[0], + args[1], + args[2], + args_bounds_check(args, 5, False), + kwargs.get("scale", None), + ) + + @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True) def aten_ops_squeeze( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 10af2ad892..314571cb86 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -2,6 +2,7 @@ activation, addmm, arange, + attention, cast, cat, condition, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py new file mode 100644 index 0000000000..9cc4a30ccf --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py @@ -0,0 +1,165 @@ +import math +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.fx.types import TRTTensor + + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + # the lower triangle of the tensor means the rows greater than and equal to the cols + row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) + col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) + rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 + ) + # get the rows + row_tensor = impl.elementwise.trunc_div( + ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col + ) + # get the cols + col_tensor = impl.elementwise.fmod( + ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col + ) + cond = impl.elementwise.ge( + ctx, target, source_ir, name + "_ge", row_tensor, col_tensor + ) + return impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", cond, [row, col] + ) + + +def scaled_dot_product_attention( + ctx: ConversionContext, + target: Union[Target, str], 
+ source_ir: Optional[SourceIR], + name: str, + query: TRTTensor, + key: TRTTensor, + value: TRTTensor, + is_causal: bool, + scale: Optional[float], +) -> TRTTensor: + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + if scale is None: + scale = query.shape[-1] + if scale < 0: + # dynamic shape + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) + else: + # static shape + sqrt_scaled = math.sqrt(scale) + scaled = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + mm, + sqrt_scaled, + ) + else: + scaled = impl.elementwise.mul( + ctx, + target, + source_ir, + name + "_scale", + mm, + scale, + ) + + if is_causal: + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, -2 + ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) + + LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) + + # this is to generate a tensor which has shape (L, S), type is int32 + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 + ) + shape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] + ) + + # since we want our attn_bias to be in float32, so cast it to float32 + shape_tensor = cast_trt_tensor( + ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir + ) + + # initialize the attn_bias as the zeros tensor + attn_bias = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 + ) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + inf_tensor = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") + ) + cond = impl.elementwise.eq( + ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True) + ) + # mask out the certain part of the attn_bias + attn_bias = impl.condition.select( + ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond + ) + + scaled = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + ) + + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", scaled, -1, False + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + + return out diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 46cb00f45c..bfa4af6df4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -96,4 +96,4 @@ def quantize( 
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) dq_output = dequantize_layer.get_output(0) - return dq_output + return dq_output \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py index 02ecf98bfe..4aa6559713 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -11,6 +11,8 @@ ) from torch_tensorrt.dynamo.types import TRTTensor +from packaging import version as pkg_version + logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ def unsqueeze( ) -> TRTTensor: from importlib.metadata import version - if version("tensorrt") < "10.7.0": + if pkg_version.parse(version("tensorrt")) < pkg_version.parse("10.7.0"): logger.warning( f"IUnsqueezeLayer is supported starting from TensorRT 10.7.0, using the old unsqueeze implementation in the current TensorRT version: {version('tensorrt')}" ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 2ecc45ecf3..553151da7a 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -9,6 +9,7 @@ from .constant_folding import constant_fold from .fuse_distributed_ops import fuse_distributed_ops from .fuse_prims_broadcast import fuse_prims_broadcast +from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .pass_manager import DynamoPassManager from .remove_assert_nodes import remove_assert_nodes from .remove_detach import remove_detach @@ -23,6 +24,7 @@ repair_input_as_output, fuse_prims_broadcast, replace_max_pool_with_indices, + lower_scaled_dot_product_attention, remove_assert_nodes, accumulate_fp32_matmul, remove_num_users_is_0_nodes, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py new file mode 100644 index 0000000000..40fd587615 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -0,0 +1,169 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + + +def lower_scaled_dot_product_attention( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Replace specific versions of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT + """ + original_fns, replacement = scaled_dot_product_attention_replacement() + replaced_nodes = [] + # For each original function, search for it in the graph and replace + for original in original_fns: + replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( + gm, + original, + replacement, + ignore_literals=True, + ) + + if replaced_nodes: + # Repair instances which use the kwargs field (specifically the "scale" kwarg) + # Also repair instances which 
+        # specified the is_causal or attn_bias fields
+        for match in replaced_nodes:
+            attention_node_replaced = None
+            # Seek the attention operator being replaced
+            for node in match.nodes_map:
+                if node.target in REPLACEABLE_ATEN_OPS:
+                    attention_node_replaced = match.nodes_map[node]
+                    break
+
+            assert attention_node_replaced is not None
+            assert len(match.replacements) == 1
+
+            new_attention_node = match.replacements[0]
+
+            assert (
+                new_attention_node.target
+                == torch.nn.functional.scaled_dot_product_attention
+            )
+
+            # Copy the metadata of the replaced attention node to the new node
+            # TODO: Investigate why there are multiple FakeTensors in the metadata.
+            # We only use the first one as it contains the output shape information for this node.
+            if "val" in attention_node_replaced.meta:
+                new_attention_node.meta["val"] = copy.copy(
+                    attention_node_replaced.meta["val"][0]
+                )
+
+            # If the attention operator had keyword-args, copy them to the new node
+            if attention_node_replaced.kwargs:
+                new_attention_node.kwargs = {**attention_node_replaced.kwargs}
+
+            # Set default args in new node:
+            # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False
+            new_attention_node.args = new_attention_node.args + (None, 0.0, False)
+
+            # The `is_causal` argument was specified
+            if (
+                (
+                    attention_node_replaced.target
+                    == torch.ops.aten._scaled_dot_product_flash_attention.default
+                )
+                and args_bounds_check(attention_node_replaced.args, 4, False)
+            ) or (
+                (
+                    attention_node_replaced.target
+                    == torch.ops.aten._scaled_dot_product_efficient_attention.default
+                )
+                and args_bounds_check(attention_node_replaced.args, 6, False)
+            ):
+                new_attention_node.args = (
+                    new_attention_node.args[:5] + (True,) + new_attention_node.args[6:]
+                )
+
+            # The `attn_bias` argument was specified; wrap it in a tuple so the
+            # concatenation below is valid (a bare Node cannot be added to a tuple)
+            if (
+                attention_node_replaced.target
+                == torch.ops.aten._scaled_dot_product_efficient_attention.default
+            ) and args_bounds_check(attention_node_replaced.args, 3) is not None:
+                new_attention_node.args = (
+                    new_attention_node.args[:3]
+                    + (attention_node_replaced.args[3],)
+                    + new_attention_node.args[4:]
+                )
+
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
+
+    return gm
+
+
+def scaled_dot_product_attention_replacement() -> Tuple[
+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
+    """Constructs the original and replacement functions for the efficient and flash attention variants"""
+
+    # Efficient Attention original graph
+    def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q,
+            k,
+            v,
+            None,
+            False,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Flash Attention original graph
+    def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
+            q,
+            k,
+            v,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Efficient Attention w/Scale original graph
+    def efficient_scale(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q,
+            k,
+            v,
+            None,
+            False,
+            scale=1.0,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Flash Attention w/Scale original graph
+    def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
+            q,
+            k,
+            v,
+            scale=1.0,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Replacement graph consists of the functional version of scaled_dot_product_attention
+    def replacement(
+        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> torch.Tensor:
+        return torch.nn.functional.scaled_dot_product_attention(query, key, value)
+
+    return (efficient, flash, efficient_scale, flash_scale), replacement
diff --git a/tools/perf/Flux/flex_perf.py b/tools/perf/Flux/flex_perf.py
new file mode 100644
index 0000000000..b7aa608dd3
--- /dev/null
+++ b/tools/perf/Flux/flex_perf.py
@@ -0,0 +1,97 @@
+from time import time
+
+import register_sdpa  # imported for its side effect: registers the custom SDPA lowering pass and converter
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+
+for i in range(torch.cuda.device_count()):
+    print(torch.cuda.get_device_properties(i).name)
+
+DEVICE = "cuda:0"
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.bfloat16,
+)
+pipe.to(DEVICE).to(torch.bfloat16)
+backbone = pipe.transformer
+
+
+batch_size = 2
+BATCH = torch.export.Dim("batch", min=1, max=8)
+
+# Only the batch dimension is marked dynamic here; txt_ids and img_ids are
+# exported with static shapes.
+dynamic_shapes = {
+    "hidden_states": {0: BATCH},
+    "encoder_hidden_states": {0: BATCH},
+    "pooled_projections": {0: BATCH},
+    "timestep": {0: BATCH},
+    "txt_ids": {},
+    "img_ids": {},
+    "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
+}
+
+settings = {
+    "strict": False,
+    "allow_complex_guards_as_runtime_asserts": True,
+    # "enabled_precisions": {torch.float16},
+    "use_explicit_typing": True,
+    "truncate_double": True,
+    "min_block_size": 1,
+    "debug": False,
+    # "use_python_runtime": True,
+    "immutable_weights": False,
+    "offload_module_to_cpu": True,
+}
+
+
+def generate_image(prompt, inference_step, batch_size=1, benchmark=False, iterations=1):
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    if benchmark:
+        print(f"Batch Size: {batch_size}")
+        print("Time elapsed for", iterations, "iterations:", end - start)
+        print(
+            "Average Latency Per Step:",
+            (end - start) / inference_step / iterations / batch_size,
+        )
+    return image
+
+
+pipe.to(torch.bfloat16)
+torch.cuda.empty_cache()
+# Warmup
+generate_image(["Test"], 20)
+print("Benchmark Original PyTorch Module Latency (bfloat16)")
+for batch_size in range(1, 3):
+    generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
+
+pipe.to(torch.float16)
+print("Benchmark Original PyTorch Module Latency (float16)")
+for batch_size in range(1, 3):
+    generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
+
+trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
+trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
+pipe.transformer = trt_gm
+
+start = time()
+generate_image(["Test"], 2, batch_size=2)
+end = time()
+print("Time elapsed for compilation:", end - start)
+print()
+print("Benchmark TRT Accelerated Latency")
+for batch_size in range(1, 3):
+    generate_image(["Test"], 20, batch_size=batch_size, benchmark=True, iterations=3)
+torch.cuda.empty_cache()
diff --git a/tools/perf/Flux/flux_quantization.py b/tools/perf/Flux/flux_quantization.py
new file mode 100644
index 0000000000..9619bd9e92
--- /dev/null
+++ b/tools/perf/Flux/flux_quantization.py
@@ -0,0 +1,268 @@
+# %%
+# Import the following libraries
+# -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
+import re
+
+import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+from diffusers.models.attention_processor import Attention
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from modelopt.torch.quantization.utils import export_torch_mode
+from torch.export._trace import _export
+from transformers import AutoModelForCausalLM
+
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+parser.add_argument(
+    "--debug",
+    action="store_true",
+    default=False,
+    help="debug mode",
+)
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8", "fp4", "fp16", "bf16", "fp32"],
+    default="fp8",
+    help="Quantization data type to use (fp8, int8, fp4, fp16, bf16, or fp32)",
+)
+
+parser.add_argument(
+    "--sdpa",
+    action="store_true",
+    default=False,
+    help="Register the custom SDPA lowering pass and converter",
+)
+
+parser.add_argument(
+    "--strong-typing",
+    action="store_true",
+    help="strong typing flag",
+)
+
+args = parser.parse_args()
+if args.sdpa:
+    import register_sdpa
+
+dtype = torch.float16
+ptq_config = None
+use_explicit_typing = args.strong_typing
+enabled_precisions = [
+    torch.float32,
+]
+
+# Update enabled precisions based on dtype argument
+if args.dtype == "fp8":
+    (
+        enabled_precisions.extend([torch.float8_e4m3fn, torch.float16])
+        if not use_explicit_typing
+        else None
+    )
+    ptq_config = mtq.FP8_DEFAULT_CFG
+elif args.dtype == "int8":
+    (
+        enabled_precisions.extend([torch.int8, torch.float16])
+        if not use_explicit_typing
+        else None
+    )
+    ptq_config = mtq.INT8_DEFAULT_CFG
+elif args.dtype == "fp4":
+    ptq_config = mtq.NVFP4_DEFAULT_CFG
+    use_explicit_typing = True
+elif args.dtype == "fp16":
+    enabled_precisions.append(torch.float16) if not use_explicit_typing else None
+elif args.dtype == "bf16":
+    dtype = torch.bfloat16
+    (
+        enabled_precisions.extend([torch.bfloat16, torch.float16])
+        if not use_explicit_typing
+        else None
+    )
+elif args.dtype == "fp32":
+    dtype = torch.float32
+else:
+    raise ValueError(f"Invalid dtype: {args.dtype}")
+print(f"\nUsing {args.dtype} quantization")
+# %%
+DEVICE = "cuda:0"
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.float16,
+)
+
+total_params = sum(p.numel() for p in pipe.transformer.parameters())
+print(f"\n Total number of parameters: {total_params/1000/1000/1000}B")
+if dtype in (torch.float16, torch.bfloat16):
+    total_size = total_params * 2 / 1024 / 1024 / 1024
+    print(f"\n Total size: {total_size}GB")
+elif dtype == torch.float32:
+    total_size = total_params * 4 / 1024 / 1024 / 1024
+    print(f"\n Total size: {total_size}GB")
+
+if args.debug:
+    pipe.transformer = FluxTransformer2DModel(
+        num_layers=1, num_single_layers=1, guidance_embeds=True
+    )
+
+pipe.to(DEVICE).to(dtype)
+# Store the config and transformer backbone
+config = pipe.transformer.config
+# global backbone
+backbone = pipe.transformer
+backbone.eval()
+
+
+def filter_func(name):
+    pattern = re.compile(
+        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
+    )
+    
return pattern.match(name) is not None + + +def generate_image(pipe, prompt, image_name): + seed = 42 + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(seed), + ).images[0] + image.save(f"{image_name}.png") + print(f"Image generated using {image_name} model saved as {image_name}.png") + + +def benchmark(prompt, inference_step, batch_size=1, iterations=1): + from time import time + + start = time() + for i in range(iterations): + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end = time() + print(f"Batch Size: {batch_size}") + print("Time Elapse for", iterations, "iterations:", end - start) + print( + "Average Latency Per Step:", + (end - start) / inference_step / iterations / batch_size, + ) + return image + + +# %% +# Quantization + + +def do_calibrate( + pipe, + prompt: str, +) -> None: + """ + Run calibration steps on the pipeline using the given prompts. + """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + +def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="test", + ) + + +if ptq_config is not None: + backbone = mtq.quantize(backbone, ptq_config, forward_loop) + mtq.disable_quantizer(backbone, filter_func) +else: + print("No quantization config provided, skipping quantization") + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=8) +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. 
+# To see this recommendation, you can try exporting using min=1, max=4096 +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {0: SEQ_LEN}, + "img_ids": {0: IMG_ID}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} +# The guidance factor is of type torch.float32 +dummy_inputs = { + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=dtype).to(DEVICE), + "encoder_hidden_states": torch.randn((batch_size, 512, 4096), dtype=dtype).to( + DEVICE + ), + "pooled_projections": torch.randn((batch_size, 768), dtype=dtype).to(DEVICE), + "timestep": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE), + "txt_ids": torch.randn((512, 3), dtype=dtype).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=dtype).to(DEVICE), + "guidance": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE), + "joint_attention_kwargs": {}, + "return_dict": False, +} + +# This will create an exported program which is going to be compiled with Torch-TensorRT +with export_torch_mode(): + ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + +with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + truncate_double=True, + min_block_size=1, + debug=args.debug, + immutable_weights=True, + offload_module_to_cpu=True, + ) + + +del ep +pipe.transformer = trt_gm +pipe.transformer.config = config + + +# %% +trt_gm.device = torch.device(DEVICE) +# Function which generates images from the flux pipeline +generate_image(pipe, ["A golden retriever"], "dog_code2") + +if not args.debug: + print(f"Benchmark TRT Module Latency at ({args.dtype}) started") + for batch_size in range(1, 9): + benchmark(["Test"], 20, batch_size=batch_size, iterations=3) + print(f"Benchmark TRT Module Latency at ({args.dtype}) ended") + +# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB diff --git a/tools/perf/Flux/flux_quantization_debug.py b/tools/perf/Flux/flux_quantization_debug.py new file mode 100644 index 0000000000..43c817ebf0 --- /dev/null +++ b/tools/perf/Flux/flux_quantization_debug.py @@ -0,0 +1,203 @@ +# %% +# Import the following libraries +# ----------------------------- +# Load the ModelOpt-modified model architecture and weights using Huggingface APIs +# Add argument parsing for dtype selection +import argparse +import re + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import torch +import torch_tensorrt +from diffusers import FluxPipeline +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from modelopt.torch.quantization.utils import export_torch_mode +from torch.export._trace import _export +from transformers import AutoModelForCausalLM + +parser = argparse.ArgumentParser( + description="Run Flux quantization with different dtypes" +) +parser.add_argument( + "--dtype", + choices=["fp8", "int8"], + default="fp8", + help="Quantization data type to use (fp8 or int8)", +) + +args = parser.parse_args() + +# Update enabled precisions based on dtype argument +if args.dtype == "fp8": + enabled_precisions = 
{torch.float8_e4m3fn, torch.float16} + ptq_config = mtq.FP8_DEFAULT_CFG +else: # int8 + enabled_precisions = {torch.int8, torch.float16} + ptq_config = mtq.INT8_DEFAULT_CFG + ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None +print(f"\nUsing {args.dtype} quantization") +# %% +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, +) +pipe.transformer = FluxTransformer2DModel( + num_layers=1, num_single_layers=1, guidance_embeds=True +) + +pipe.to(DEVICE).to(torch.float16) +# Store the config and transformer backbone +config = pipe.transformer.config +# global backbone +backbone = pipe.transformer +backbone.eval() + + +def filter_func(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" + ) + return pattern.match(name) is not None + + +def generate_image(pipe, prompt, image_name): + seed = 42 + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(seed), + ).images[0] + image.save(f"{image_name}.png") + print(f"Image generated using {image_name} model saved as {image_name}.png") + + +def benchmark(prompt, inference_step, batch_size=1, iterations=1): + from time import time + + start = time() + for i in range(iterations): + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end = time() + print(f"Batch Size: {batch_size}") + print("Time Elapse for", iterations, "iterations:", end - start) + print( + "Average Latency Per Step:", + (end - start) / inference_step / iterations / batch_size, + ) + return image + + +# %% +# Quantization + + +def do_calibrate( + pipe, + prompt: str, +) -> None: + """ + Run calibration steps on the pipeline using the given prompts. + """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + +def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="test", + ) + + +backbone = mtq.quantize(backbone, ptq_config, forward_loop) +mtq.disable_quantizer(backbone, filter_func) + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=8) +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. 
+# To see this recommendation, you can try exporting using min=1, max=4096 +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {0: SEQ_LEN}, + "img_ids": {0: IMG_ID}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} +# The guidance factor is of type torch.float32 +dummy_inputs = { + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to( + DEVICE + ), + "encoder_hidden_states": torch.randn( + (batch_size, 512, 4096), dtype=torch.float16 + ).to(DEVICE), + "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to( + DEVICE + ), + "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE), + "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE), + "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE), + "joint_attention_kwargs": {}, + "return_dict": False, +} + +# This will create an exported program which is going to be compiled with Torch-TensorRT +with export_torch_mode(): + ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + +trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions=enabled_precisions, + use_explicit_typing=False, + truncate_double=True, + min_block_size=1, + debug=False, + # use_python_runtime=True, + immutable_weights=True, + offload_module_to_cpu=True, +) + + +del ep +pipe.transformer = trt_gm +pipe.transformer.config = config + + +# %% +trt_gm.device = torch.device(DEVICE) +# Function which generates images from the flux pipeline +generate_image(pipe, ["A golden retriever"], "dog_code2") + + +# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB diff --git a/tools/perf/Flux/register_sdpa.py b/tools/perf/Flux/register_sdpa.py new file mode 100644 index 0000000000..9afbbda7d4 --- /dev/null +++ b/tools/perf/Flux/register_sdpa.py @@ -0,0 +1,184 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from sdpa_converter import * +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + +# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention +# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it. 
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default) +TORCH_TRT_DECOMPOSITIONS.pop( + torch.ops.aten._scaled_dot_product_efficient_attention.default +) +TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default) + +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + + +@_aten_lowering_pass +def lower_scaled_dot_product_attention( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Replace specific versions of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT + """ + original_fns, replacement = scaled_dot_product_attention_replacement() + replaced_nodes = [] + # For each original function, search for it in the graph and replace + for original in original_fns: + replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( + gm, + original, + replacement, + ignore_literals=True, + ) + + if replaced_nodes: + # Repair instances which use the kwargs field (specifically the "scale" kwarg) + # Also repair instances which specified the is_causal or attn_bias fields + for match in replaced_nodes: + attention_node_replaced = None + # Seek the attention operator being replaced + for node in match.nodes_map: + if node.target in REPLACEABLE_ATEN_OPS: + attention_node_replaced = match.nodes_map[node] + break + + assert attention_node_replaced is not None + assert len(match.replacements) == 1 + + new_attention_node = match.replacements[0] + + assert ( + new_attention_node.target + == torch.nn.functional.scaled_dot_product_attention + ) + + # Copy the metadata of the replaced attention node to the new node + # TODO: Investigate why there are multiple FakeTensors in the metadata. + # We only use the first one as it contains the output shape information for this node. + if "val" in attention_node_replaced.meta: + new_attention_node.meta["val"] = copy.copy( + attention_node_replaced.meta["val"][0] + ) + + # If the attention operator had keyword-args, copy them to the new node + if attention_node_replaced.kwargs: + new_attention_node.kwargs = {**attention_node_replaced.kwargs} + + # Set default args in new node: + # Tensor? 
attn_mask=None, float dropout_p=0.0, bool is_causal=False + new_attention_node.args = new_attention_node.args + (None, 0.0, False) + + # The `is_causal` argument was specified + if ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 4, False) + ) or ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 6, False) + ): + new_attention_node.args = ( + new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] + ) + + # The `attn_bias` argument was specified + if ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) and args_bounds_check(attention_node_replaced.args, 3) is not None: + new_attention_node.args = ( + new_attention_node.args[:3] + + attention_node_replaced.args[3] + + new_attention_node.args[4:] + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") + + return gm + + +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for efficient attention""" + + # Efficient Attention original graph + def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention original graph + def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + ) + out = operator.getitem(outputs, 0) + return out + + # Efficient Attention w/Scale original graph + def efficient_scale( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention w/Scale original graph + def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Replacement graph consists of the functional version of scaled_dot_product_attention + def replacement( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(query, key, value) + + return (efficient, flash, efficient_scale, flash_scale), replacement diff --git a/tools/perf/Flux/sdpa_converter.py b/tools/perf/Flux/sdpa_converter.py new file mode 100644 index 0000000000..903324dff5 --- /dev/null +++ b/tools/perf/Flux/sdpa_converter.py @@ -0,0 +1,176 @@ +import logging +import math +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +import torch_tensorrt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from 
torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.fx.types import TRTTensor + +logger = logging.getLogger(__name__) + + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + row: TRTTensor, + col: TRTTensor, +) -> TRTTensor: + row_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 + ) + row_reshape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] + ) + + col_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 + ) + col_reshape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] + ) + + mask = impl.elementwise.ge( + ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor + ) + return mask + + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, + enabled=True, + supports_dynamic_shapes=True, +) +def scaled_dot_product_attention( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> TRTTensor: + # TODO: Handle attn_mask and is_causal arguments in the future + query, key, value, attn_mask, dropout_p, is_causal = args + logger.info( + "Ignoring attn_mask and is_causal arguments provided by the original graph. " + "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True " + "and for generate phase, is_causal=False since we pass only 1 input token at a time" + ) + + # TODO: remove this once we have a better way to handle the causal mask + scale = kwargs.get("scale", None) + source_ir = SourceIR.ATEN + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + if scale is None: + scale = query.shape[-1] + if scale < 0: + # dynamic shape + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) + else: + # static shape + sqrt_scaled = math.sqrt(scale) + scaled = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + mm, + sqrt_scaled, + ) + else: + scaled = impl.elementwise.mul( + ctx, + target, + source_ir, + name + "_scale", + mm, + scale, + ) + + # If is_causal is True, we need to generate a causal mask + if is_causal: + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, 2 + ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + 
"_logical_not", tril_tensor + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir + ) + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) + + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + ) + else: + scaled_add_attn_bias = scaled + + # Create a if condition to check if is_causal is True + if isinstance(is_causal, TRTTensor): + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + scaled_add_attn_bias = output_layer.get_output(0) + + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + + return out