From 2e4825f9bbd57fbf8bf511555048f7e836353836 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Tue, 22 Jul 2025 12:26:22 -0700
Subject: [PATCH 1/2] revert linear converter

---
 .../dynamo/conversion/aten_ops_converters.py | 19 +++++++
 .../dynamo/conversion/impl/__init__.py       |  1 +
 .../dynamo/conversion/impl/linear.py         | 56 +++++++++++++++++++
 .../dynamo/lowering/_decomposition_groups.py |  1 +
 4 files changed, 77 insertions(+)
 create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/linear.py

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 8c0706539c..f1a7f9a8fc 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3579,3 +3579,22 @@ def aten_ops_nonzero(
         name,
         args[0],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
+def aten_ops_linear(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.linear.linear(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args_bounds_check(args, 2, None),
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index 10af2ad892..61728392da 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -12,6 +12,7 @@
     embedding,
     full,
     grid,
+    linear,
     matmul,
     nccl_ops,
     normalization,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/linear.py b/py/torch_tensorrt/dynamo/conversion/impl/linear.py
new file mode 100644
index 0000000000..5e859a46d3
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/linear.py
@@ -0,0 +1,56 @@
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.dynamo.types import TRTTensor
+
+
+def linear(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    weight: Union[TRTTensor, torch.Tensor, np.ndarray],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+) -> TRTTensor:
+    # Convert the weight to a TRT tensor if it is a torch/numpy constant
+    if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
+        raise RuntimeError(
+            f"Linear layer {name} has weight of type {type(weight)}; expected TRTTensor, torch.Tensor, or np.ndarray."
+        )
+    elif isinstance(weight, (torch.Tensor, np.ndarray)):
+        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+
+    # Convert the bias to a TRT tensor if it is a torch/numpy constant
+    if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
+        raise RuntimeError(
+            f"Linear layer {name} has bias of type {type(bias)}; expected TRTTensor, torch.Tensor, or np.ndarray."
+        )
+    elif isinstance(bias, (torch.Tensor, np.ndarray)):
+        bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+
+    # Add an IMatrixMultiplyLayer, transposing the weight to match aten.linear semantics
+    out = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_matrix_multiply",
+        input,
+        weight,
+        input_matrix_op=trt.MatrixOperation.NONE,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+
+    if bias is not None:
+        # Add the bias elementwise
+        out = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add_bias", out, bias
+        )
+
+    return out
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index 825be75076..52b541d3a8 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -171,6 +171,7 @@
     aten.upsample_bilinear2d.vec,
     aten.upsample_trilinear3d.vec,
     aten.upsample_bicubic2d.vec,
+    aten.linear.default,
 }

From aae44281ce424346a5ccc13d772385929c5362de Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Tue, 22 Jul 2025 12:50:34 -0700
Subject: [PATCH 2/2] add unit tests

---
 tests/py/dynamo/conversion/test_linear_aten.py | 54 +++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 tests/py/dynamo/conversion/test_linear_aten.py

diff --git a/tests/py/dynamo/conversion/test_linear_aten.py b/tests/py/dynamo/conversion/test_linear_aten.py
new file mode 100644
index 0000000000..2426b7b42d
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_linear_aten.py
@@ -0,0 +1,54 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestLinearConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (10, 10),
+            (10, 100),
+            (100, 10),
+            (100, 100),
+        ]
+    )
+    def test_linear_converter(self, in_features, out_features):
+        class LinearModel(nn.Module):
+            def __init__(self, in_features, out_features):
+                super().__init__()
+                self.linear = nn.Linear(in_features, out_features)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = LinearModel(in_features, out_features).eval().cuda()
+        inputs = [torch.randn(int(torch.randint(1, 20, (1,))), in_features).cuda()]
+        self.run_test(model, inputs, use_dynamo_tracer=True, enable_passes=True)
+
+    def test_linear_with_dynamic_shape(self):
+        class LinearModel(torch.nn.Module):
+            def forward(self, x, weight, bias):
+                return torch.ops.aten.linear.default(x, weight, bias)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=(1, 10),
+                opt_shape=(10, 10),
+                max_shape=(100, 10),
+            ),
+            Input(dtype=torch.float32, shape=(20, 10)),
+            Input(dtype=torch.float32, shape=(20,)),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            LinearModel(), input_specs, use_dynamo_tracer=True, enable_passes=True
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
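
Reviewer note (not part of the patches): the converter lowers aten.linear.default
as a matrix multiply with the weight transposed, followed by an elementwise bias
add, i.e. y = x @ W^T + b. Below is a minimal sanity sketch of that equivalence
in plain PyTorch. The commented end-to-end call assumes a CUDA machine with
torch_tensorrt installed; its exact kwargs are an assumption for illustration,
not the patch's tested usage.

    import torch

    x = torch.randn(8, 10)
    w = torch.randn(20, 10)  # nn.Linear stores weight as (out_features, in_features)
    b = torch.randn(20)

    # aten.linear.default computes x @ W^T + b, matching the converter's lowering
    ref = torch.ops.aten.linear.default(x, w, b)
    manual = x @ w.transpose(-2, -1) + b
    assert torch.allclose(ref, manual, atol=1e-6)

    # Assumed end-to-end check (requires CUDA + torch_tensorrt); with
    # aten.linear.default kept out of the decompositions, the op should
    # reach the new converter directly:
    # import torch_tensorrt
    # model = torch.nn.Linear(10, 20).eval().cuda()
    # trt_model = torch_tensorrt.compile(
    #     model, ir="dynamo", inputs=[torch.randn(8, 10).cuda()]
    # )
    # print(trt_model(torch.randn(8, 10).cuda()).shape)  # torch.Size([8, 20])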