diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
index ab9629b0db..42880646c3 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -2,7 +2,6 @@ import warnings
 from typing import Any, Callable, Optional, Union
 
-import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -103,6 +102,14 @@ def convert_binary_elementwise(
         rhs_dtype = rhs_val.dtype
         is_rhs_trt_tensor = True
 
+    # Handle scalar tensor type promotion for elementwise operations
+    # When one operand is a scalar tensor (0-dimensional), promote its dtype to match the other operand
+    # This ensures consistent type handling in Torch elementwise operations
+    if isinstance(lhs_val, torch.Tensor) and len(lhs_val.shape) == 0:
+        lhs_dtype = rhs_dtype
+    if isinstance(rhs_val, torch.Tensor) and len(rhs_val.shape) == 0:
+        rhs_dtype = lhs_dtype
+
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
         warnings.warn(
             f"Both operands of the binary elementwise op {name} "
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
index 83ea3dd99b..dfc917bc00 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
@@ -8,7 +8,6 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     broadcast,
-    cast_trt_tensor,
     get_trt_tensor,
     set_layer_name,
 )
@@ -48,16 +47,6 @@ def matrix_multiply(
     input, other = broadcast(
         ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
-    if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
-        promoted_type = _enums.dtype._from(
-            torch.promote_types(
-                _enums.dtype._from(input.dtype).to(torch.dtype),
-                _enums.dtype._from(other.dtype).to(torch.dtype),
-            )
-        )
-        trt_promoted_type = promoted_type.to(trt.DataType)
-        input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
-        other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")
     layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
     set_layer_name(layer, target, name, source_ir)
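
Note (illustration only, not part of the patch): the snippet below is a minimal standalone sketch of the promotion rule added to convert_binary_elementwise above. The helper name promote_scalar_operand_dtypes and the example dtypes are hypothetical; the body only mirrors the two isinstance checks introduced by the diff.

import torch

def promote_scalar_operand_dtypes(lhs_val, rhs_val, lhs_dtype, rhs_dtype):
    # Mirrors the added logic: a 0-dimensional torch.Tensor operand adopts the
    # dtype of the other operand, so both sides of the elementwise op agree.
    if isinstance(lhs_val, torch.Tensor) and len(lhs_val.shape) == 0:
        lhs_dtype = rhs_dtype
    if isinstance(rhs_val, torch.Tensor) and len(rhs_val.shape) == 0:
        rhs_dtype = lhs_dtype
    return lhs_dtype, rhs_dtype

# Example: a float16 network operand paired with a float32 0-dim scalar tensor;
# the scalar's dtype is promoted to float16 to match the other operand.
print(promote_scalar_operand_dtypes(None, torch.tensor(2.0), torch.float16, torch.float32))
# -> (torch.float16, torch.float16)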