Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,26 @@ def aten_ops_amax(
)


@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
def aten_ops_sum(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.reduce.sum(
network,
target,
SourceIR.ATEN,
name,
args[0],
args_bounds_check(args, 1, replacement=None),
args_bounds_check(args, 2, replacement=False),
)


@dynamo_tensorrt_converter(torch.ops.aten.exp.default) # type: ignore[misc]
def aten_ops_exp(
network: TRTNetwork,
Expand Down
28 changes: 27 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/reduce.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Optional, Sequence, Tuple, Union

import tensorrt as trt
from torch.fx.node import Target
Expand Down Expand Up @@ -33,3 +33,29 @@ def amax(
)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def sum(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
dim: Optional[Union[int, Sequence[int]]] = None,
keepdim: bool = False,
) -> TRTTensor:
if (isinstance(input_val, TRTTensor)) and (
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
input_val = cast_trt_tensor(network, input_val, trt.float32, name)

if dim is None:
dim = tuple(range(len(input_val.shape)))
layer = network.add_reduce(
input_val,
trt.ReduceOperation.SUM,
axes=get_axes_for_reduce_op(dim),
keep_dims=keepdim,
)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
114 changes: 114 additions & 0 deletions tests/py/dynamo/conversion/test_sum_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestSumConverter(DispatchTestCase):
@parameterized.expand(
[
((3, 2, 4),),
((2, 3, 4, 5),),
((2, 3, 4, 5),),
((6, 7, 5, 4, 5),),
]
)
def test_sum_dim_int_default(self, input_shape):
class Sum(nn.Module):
def forward(self, x):
return torch.sum(x)

inputs = [torch.randn(*input_shape)]
self.run_test(
Sum(),
inputs,
expected_ops={torch.ops.aten.sum.default},
)

@parameterized.expand(
[
((3, 2, 4), 1, True),
((2, 3, 4, 5), 3, True),
((2, 3, 4, 5), None, False),
((6, 7, 5, 4, 5), 4, False),
]
)
def test_sum_dim_int(self, input_shape, dim, keep_dims):
class Sum(nn.Module):
def forward(self, x):
return torch.sum(x, dim=dim, keepdim=keep_dims)

inputs = [torch.randn(*input_shape)]
self.run_test(
Sum(),
inputs,
expected_ops={torch.ops.aten.sum.dim_IntList},
)

@parameterized.expand(
[
((3, 2, 4), [1], True),
((2, 1, 4, 5), None, True),
((2, 3, 4, 5), [0, 1, 2, 3], False),
((6, 7, 5, 4, 5), [1, 3, 4], False),
]
)
def test_sum_dim_tuple(self, input_shape, dim, keep_dims):
class Sum(nn.Module):
def forward(self, x):
return torch.sum(x, dim=dim, keepdim=keep_dims)

inputs = [torch.randn(*input_shape)]
self.run_test(
Sum(),
inputs,
expected_ops={torch.ops.aten.sum.dim_IntList},
)

@parameterized.expand(
[
((3, 2, 4), 1, True, torch.int, 0, 5),
((2, 3, 4, 5), None, True, torch.int, -10, 10),
((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
]
)
def test_sum_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
class Sum(nn.Module):
def forward(self, x):
return torch.sum(x, dim=dim, keepdim=keep_dims)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
Sum(),
inputs,
expected_ops={torch.ops.aten.sum.dim_IntList},
check_dtype=False,
)

@parameterized.expand(
[
((3, 2, 4), [1], True, torch.int, 0, 5),
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
((2, 3, 4, 5), None, False, torch.int32, -5, 0),
((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
]
)
def test_sum_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
class Sum(nn.Module):
def forward(self, x):
return torch.sum(x, dim=dim, keepdim=keep_dims)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
Sum(),
inputs,
expected_ops={torch.ops.aten.sum.dim_IntList},
check_dtype=False,
)


if __name__ == "__main__":
run_tests()