diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index e955c7278b..7986b3510d 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -280,6 +280,37 @@ def create_constant(
     return constant.get_output(0)
 
 
+def cast_trt_tensor(
+    network: TRTNetwork,
+    input_val: TRTTensor,
+    dtype: TRTDataType,
+    name: str,
+) -> TRTTensor:
+    """
+    Given a TRT Tensor, convert that Tensor to the specified dtype
+
+    Adds an Identity layer to the network which performs the conversion
+
+    Args:
+        network (TRTNetwork): A TensorRT network
+        input_val (TRTTensor): A TRT Tensor to cast to a new data type
+        dtype (TRTDataType): The TRTDataType to cast the input Tensor to
+        name (str): Name of the calling layer
+
+    Returns:
+        A TensorRT ITensor which has been cast to the specified dtype
+    """
+    if input_val.dtype != dtype:
+        identity_layer = network.add_identity(input_val)
+        identity_layer.set_output_type(0, dtype)
+        identity_layer.name = (
+            f"Cast ITensor {input_val.name} from {input_val.dtype} to {dtype} - {name}"
+        )
+        return identity_layer.get_output(0)
+    else:
+        return input_val
+
+
 def get_trt_tensor(
     network: TRTNetwork,
     input_val: Any,
diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/base.py b/py/torch_tensorrt/fx/converters/impl/elementwise/base.py
index e79d5048cb..a7ab189e72 100644
--- a/py/torch_tensorrt/fx/converters/impl/elementwise/base.py
+++ b/py/torch_tensorrt/fx/converters/impl/elementwise/base.py
@@ -17,6 +17,7 @@
     broadcast,
     squeeze_left,
     get_trt_tensor,
+    cast_trt_tensor,
 )
 
 
@@ -52,6 +53,7 @@ def convert_binary_elementwise(
     introduce constant via .size() op. Other scenario should be const folded first.
     If any operand is not a trt tensor, we make it a trt constant layer while preserve
     its dtype. Then we broadcast these two inputs to have the same number of dimensions.
+    We also promote the types of the two tensors to avoid dtype errors in TRT.
 
     Limitation:
         If we are using implicit batch dim mode, the operand that is not a trt
@@ -126,6 +128,17 @@ def convert_binary_elementwise(
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
 
+    promoted_type = torch.promote_types(
+        unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH),
+        unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH),
+    )
+    trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)
+
+    if trt_promoted_type != lhs_val.dtype:
+        lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
+    if trt_promoted_type != rhs_val.dtype:
+        rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
+
     # Check the limitation in the doc string.
     if network.has_implicit_batch_dimension:
         if is_lhs_trt_tensor and not is_rhs_trt_tensor:
diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py
index e122c2a414..a18556d4c4 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py
@@ -57,6 +57,23 @@ def forward(self, x):
         inputs = [torch.rand(1, 1) + 1]
         self.run_test(m, inputs, expected_ops={expected_op})
 
+    @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
+    def test_elementwise_ops_mismatched_dtypes(
+        self, name, orig_op: Callable, expected_op
+    ):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                return self.orig_op(x.int(), x)
+
+        m = TestModule(orig_op)
+        # Avoid dividing by 0.
+        inputs = [2 * torch.rand(1, 1, dtype=torch.float) + 1]
+        self.run_test(m, inputs, expected_ops={expected_op})
+
     @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
     def test_elementwise_ops_with_one_constant(
         self, name, orig_op: Callable, expected_op