diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 451d218ee7..a6b323602a 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -471,6 +471,26 @@ def aten_ops_amax(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
+def aten_ops_sum(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.reduce.sum(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, replacement=None),
+        args_bounds_check(args, 2, replacement=False),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
 def aten_ops_exp(
     network: TRTNetwork,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
index 53070761dd..c57bba48ac 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Sequence, Tuple, Union
 
 import tensorrt as trt
 from torch.fx.node import Target
@@ -33,3 +33,29 @@ def amax(
     )
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def sum(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: Optional[Union[int, Sequence[int]]] = None,
+    keepdim: bool = False,
+) -> TRTTensor:
+    if (isinstance(input_val, TRTTensor)) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
+
+    if dim is None:
+        dim = tuple(range(len(input_val.shape)))
+    layer = network.add_reduce(
+        input_val,
+        trt.ReduceOperation.SUM,
+        axes=get_axes_for_reduce_op(dim),
+        keep_dims=keepdim,
+    )
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_sum_aten.py b/tests/py/dynamo/conversion/test_sum_aten.py
new file mode 100644
index 0000000000..a6cfdc3b15
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_sum_aten.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestSumConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+            ((2, 3, 4, 5),),
+            ((2, 3, 4, 5),),
+            ((6, 7, 5, 4, 5),),
+        ]
+    )
+    def test_sum_dim_int_default(self, input_shape):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.default},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True),
+            ((2, 3, 4, 5), 3, True),
+            ((2, 3, 4, 5), None, False),
+            ((6, 7, 5, 4, 5), 4, False),
+        ]
+    )
+    def test_sum_dim_int(self, input_shape, dim, keep_dims):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.dim_IntList},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1], True),
+            ((2, 1, 4, 5), None, True),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False),
+        ]
+    )
+    def test_sum_dim_tuple(self, input_shape, dim, keep_dims):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.dim_IntList},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True, torch.int, 0, 5),
+            ((2, 3, 4, 5), None, True, torch.int, -10, 10),
+            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
+        ]
+    )
+    def test_sum_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.dim_IntList},
+            check_dtype=False,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1], True, torch.int, 0, 5),
+            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
+            ((2, 3, 4, 5), None, False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
+        ]
+    )
+    def test_sum_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.dim_IntList},
+            check_dtype=False,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
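
Reviewer note (not part of the diff): a minimal, standalone sketch of how the two registered overloads are reached from torch.sum, and how the dim=None branch in impl.reduce.sum falls back to an all-axes reduction. The axes_bitmask helper is a hypothetical stand-in for get_axes_for_reduce_op, assuming TensorRT reduce layers take a bitmask in which bit i selects axis i; only the torch calls are guaranteed APIs here.

import torch

x = torch.randn(2, 3, 4)

# torch.sum(x) with no dim argument lowers to torch.ops.aten.sum.default (full reduction),
# which is what test_sum_dim_int_default exercises.
full = torch.ops.aten.sum.default(x)

# torch.sum(x, dim=[0, 2], keepdim=True) lowers to torch.ops.aten.sum.dim_IntList,
# covered by the remaining test cases.
partial = torch.ops.aten.sum.dim_IntList(x, [0, 2], True)

# Mirror of the converter's dim handling: when dim is None, reduce over every axis.
dim = None
if dim is None:
    dim = tuple(range(len(x.shape)))

# Hypothetical stand-in for get_axes_for_reduce_op: build an axes bitmask
# where bit i selects axis i of the input.
def axes_bitmask(dims):
    mask = 0
    for d in dims:
        mask |= 1 << d
    return mask

assert axes_bitmask(dim) == 0b111        # all three axes of x selected
print(full.shape, partial.shape)         # torch.Size([]) torch.Size([1, 3, 1])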