From 95572caf031c76fca6b1fc53e0317ffa5c7ec1b7 Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Wed, 5 Mar 2025 16:31:37 +0000
Subject: [PATCH 1/2] fix: conv parameter check failure

---
 py/torch_tensorrt/dynamo/conversion/impl/conv.py    | 10 ++++++++--
 tests/py/dynamo/conversion/test_convolution_aten.py |  5 ++++-
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
index 903745a38a..0d79e39ddf 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -6,7 +6,6 @@
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
-
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -45,6 +44,8 @@ def convNd(
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."
 
+    num_dims = len(input.shape) - 2
+
     if is_conv1d:
         # Apply an unsqueeze operation to transform the conv1d problem into conv2d
         input = impl.unsqueeze.unsqueeze(
@@ -104,7 +105,12 @@ def convNd(
         conv_layer.set_input(2, bias)
 
     # Cast certain fields to tuples, in accordance with TRT requirements
-    padding = (padding,) if isinstance(padding, int) else padding
+    if isinstance(padding, int):
+        padding = (padding,) * num_dims
+    elif isinstance(padding, (list, tuple)):
+        padding = tuple(padding)
+        if len(padding) == 1:
+            padding = (padding[0],) * num_dims
     stride = (stride,) if isinstance(stride, int) else stride
     dilation = (dilation,) if isinstance(dilation, int) else dilation
 
diff --git a/tests/py/dynamo/conversion/test_convolution_aten.py b/tests/py/dynamo/conversion/test_convolution_aten.py
index 0f043d85cd..d8284f7311 100644
--- a/tests/py/dynamo/conversion/test_convolution_aten.py
+++ b/tests/py/dynamo/conversion/test_convolution_aten.py
@@ -1,7 +1,6 @@
 import torch
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -134,6 +133,8 @@ def forward(self, x):
             param("no_bias", 1, bias=False),
             ("tuple_parameters", 1, (1, 1), (1, 1)),
             param("non_zero_padding", 1, padding=1),
+            param("list_zero_padding", 1, padding=[0]),
+            param("list_non_padding", 1, padding=[1]),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
         ]
@@ -205,6 +206,8 @@ def forward(self, x):
             param("no_bias", 1, bias=False),
             ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)),
             param("non_zero_padding", 1, padding=1),
+            param("list_zero_padding", 1, padding=[0]),
+            param("list_non_padding", 1, padding=[1]),
             param("dilation", 1, dilation=2),
             ## TODO TRT 8.4.1 will trigger issue with this test. T127981773
             # param("groups", 1, groups=3),

From 294dcf765c2300e196ebe42ceaf78f93c114d0aa Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Thu, 13 Mar 2025 12:40:56 +0000
Subject: [PATCH 2/2] fix: 1d_conv flag and parameter check failure for stride
 and dilation

---
 .../dynamo/conversion/aten_ops_converters.py        |  5 +++--
 py/torch_tensorrt/dynamo/conversion/impl/conv.py    | 16 ++++++++++++++--
 .../dynamo/conversion/test_convolution_aten.py      |  9 +++++++--
 3 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 2a9255ed68..9f5e7dc768 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2468,13 +2468,14 @@ def aten_ops_convolution(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     is_transposed = args[6]
+    is_conv1d = len(args[0].shape) == 3
     if not is_transposed:
         return impl.conv.convNd(
             ctx,
             target,
             source_ir=SourceIR.ATEN,
             name=name,
-            is_conv1d=len(args[3]) == 1,
+            is_conv1d=is_conv1d,
             input=args[0],
             weight=args[1],
             bias=args_bounds_check(args, 2, None),
@@ -2489,7 +2490,7 @@ def aten_ops_convolution(
             target,
             source_ir=SourceIR.ATEN,
             name=name,
-            is_deconv1d=len(args[3]) == 1,
+            is_deconv1d=is_conv1d,
             input=args[0],
             weight=args[1],
             bias=args_bounds_check(args, 2, None),
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
index 0d79e39ddf..25419d7f60 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -111,8 +111,20 @@ def convNd(
         padding = tuple(padding)
         if len(padding) == 1:
             padding = (padding[0],) * num_dims
-    stride = (stride,) if isinstance(stride, int) else stride
-    dilation = (dilation,) if isinstance(dilation, int) else dilation
+
+    if isinstance(stride, int):
+        stride = (stride,) * num_dims
+    elif isinstance(stride, (list, tuple)):
+        stride = tuple(stride)
+        if len(stride) == 1:
+            stride = (stride[0],) * num_dims
+
+    if isinstance(dilation, int):
+        dilation = (dilation,) * num_dims
+    elif isinstance(dilation, (list, tuple)):
+        dilation = tuple(dilation)
+        if len(dilation) == 1:
+            dilation = (dilation[0],) * num_dims
 
     # Expand parameters manually for Conv1D computations
     if is_conv1d:
diff --git a/tests/py/dynamo/conversion/test_convolution_aten.py b/tests/py/dynamo/conversion/test_convolution_aten.py
index d8284f7311..78d7fc4cca 100644
--- a/tests/py/dynamo/conversion/test_convolution_aten.py
+++ b/tests/py/dynamo/conversion/test_convolution_aten.py
@@ -132,10 +132,13 @@ def forward(self, x):
             ("default", 1),
             param("no_bias", 1, bias=False),
             ("tuple_parameters", 1, (1, 1), (1, 1)),
+            param("list_stride", 2, stride=[2]),
             param("non_zero_padding", 1, padding=1),
             param("list_zero_padding", 1, padding=[0]),
             param("list_non_padding", 1, padding=[1]),
-            param("dilation", 1, dilation=2),
+            param("dilation", 2, dilation=3),
+            param("tuple_dilation", 2, dilation=(3, 3)),
+            param("list_dilation", 2, dilation=[3]),
             param("groups", 1, groups=3),
         ]
     )
@@ -205,10 +208,12 @@ def forward(self, x):
             ("default", 1),
             param("no_bias", 1, bias=False),
             ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)),
+            param("list_stride", 2, stride=[2]),
             param("non_zero_padding", 1, padding=1),
             param("list_zero_padding", 1, padding=[0]),
             param("list_non_padding", 1, padding=[1]),
-            param("dilation", 1, dilation=2),
+            param("dilation", 2, dilation=2),
+            param("list_dilation", 2, dilation=[2]),
             ## TODO TRT 8.4.1 will trigger issue with this test. T127981773
             # param("groups", 1, groups=3),
         ]
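
Note on the patches above: the three normalization blocks added to convNd all follow one pattern — an int is broadcast across every spatial dimension, a 1-element list or tuple is expanded the same way, and full-length tuples pass through unchanged. Below is a minimal, self-contained Python sketch of that pattern; the helper name normalize_conv_param is hypothetical, chosen for illustration, and is not a torch_tensorrt API.

    # Sketch of the parameter normalization PATCH 1/2 and 2/2 add to convNd.
    # normalize_conv_param is an illustrative name, not a torch_tensorrt API.
    from typing import Sequence, Tuple, Union

    def normalize_conv_param(
        value: Union[int, Sequence[int]], num_dims: int
    ) -> Tuple[int, ...]:
        """Expand an int or 1-element sequence to a tuple of length num_dims,
        mirroring the padding/stride/dilation handling in impl/conv.py."""
        if isinstance(value, int):
            return (value,) * num_dims
        value = tuple(value)
        if len(value) == 1:
            return (value[0],) * num_dims
        return value

    # For a conv2d input of shape (N, C, H, W): num_dims = len(shape) - 2 = 2.
    assert normalize_conv_param(1, 2) == (1, 1)       # int is broadcast
    assert normalize_conv_param([2], 2) == (2, 2)     # 1-element list is broadcast
    assert normalize_conv_param((3, 3), 2) == (3, 3)  # full tuple passes through

This is also why PATCH 2/2 derives is_conv1d from the input rank (len(args[0].shape) == 3) instead of from len(args[3]) == 1: once a 1-element stride list is accepted for 2D and 3D convolutions, stride length no longer identifies a conv1d.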