Skip to content

Commit bb5a7c9

Browse files
committed
feat: support cumsum dynamo converter
1 parent acc248b commit bb5a7c9

File tree

3 files changed

+141
-1
lines changed

3 files changed

+141
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,29 @@ def aten_ops_chunk(
691691
)
692692

693693

694+
@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default) # type: ignore[misc]
695+
@enforce_tensor_types(
696+
{
697+
0: (TRTTensor,),
698+
}
699+
) # type: ignore[misc]
700+
def aten_ops_cumsum(
701+
ctx: ConversionContext,
702+
target: Target,
703+
args: Tuple[Argument, ...],
704+
kwargs: Dict[str, Argument],
705+
name: str,
706+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
707+
return impl.slice.cumsum(
708+
ctx,
709+
target,
710+
SourceIR.ATEN,
711+
name,
712+
args[0],
713+
args[1],
714+
)
715+
716+
694717
@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
695718
@enforce_tensor_types(
696719
{

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
import math
22
from typing import Optional
33

4+
import tensorrt as trt
5+
import torch
46
from torch.fx.node import Target
57
from torch_tensorrt.dynamo._SourceIR import SourceIR
8+
from torch_tensorrt.dynamo.conversion import impl
69
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7-
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
10+
from torch_tensorrt.dynamo.conversion.converter_utils import (
11+
get_positive_dim,
12+
get_trt_tensor,
13+
)
814
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
915
from torch_tensorrt.fx.converters.converter_utils import (
16+
Frameworks,
1017
has_dynamic_shape,
1118
prepend_ones,
1219
set_layer_name,
20+
unified_dtype_converter,
1321
)
1422
from torch_tensorrt.fx.types import Shape, TRTTensor
1523

@@ -157,3 +165,43 @@ def chunk(
157165
cnt += 1
158166

159167
return result
168+
169+
170+
def cumsum(
171+
ctx: ConversionContext,
172+
target: Target,
173+
source_ir: Optional[SourceIR],
174+
name: str,
175+
input: TRTTensor,
176+
dim: int,
177+
) -> TRTTensor:
178+
input_shape = input.shape
179+
dim = get_positive_dim(dim, len(input_shape))
180+
loop = ctx.net.add_loop()
181+
axis = np.array(input_shape[dim])
182+
trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
183+
loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
184+
iterator = loop.add_iterator(input, dim, reverse=False)
185+
data = iterator.get_output(0)
186+
new_dims = tuple(data.shape)
187+
zeros = np.zeros(new_dims)
188+
zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
189+
190+
running_sum = loop.add_recurrence(zero_trttensor)
191+
set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
192+
running_sum_tensor = running_sum.get_output(0)
193+
194+
current_sum = impl.elementwise.add(
195+
ctx,
196+
target,
197+
source_ir,
198+
f"{name}_elementwise_add",
199+
data,
200+
running_sum_tensor,
201+
)
202+
running_sum.set_input(1, current_sum)
203+
204+
loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
205+
set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
206+
loop_output.set_input(1, trip_limit)
207+
return loop_output.get_output(0)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestCumsumConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((1,), 0),
13+
((2,), 0),
14+
((3,), -1),
15+
]
16+
)
17+
def test_cumsum_1D(self, shape, dim):
18+
class Cumsum(nn.Module):
19+
def forward(self, x):
20+
return torch.ops.aten.cumsum.default(x, dim)
21+
22+
inputs = [torch.randn(shape)]
23+
self.run_test(
24+
Cumsum(),
25+
inputs,
26+
)
27+
28+
@parameterized.expand(
29+
[
30+
((3, 1), 0),
31+
((3, 1), 1),
32+
((2, 3), -1),
33+
((2, 3), -2),
34+
]
35+
)
36+
def test_cumsum_2D(self, shape, dims):
37+
class Cumsum(nn.Module):
38+
def forward(self, x):
39+
return torch.ops.aten.cumsum.default(x, dims)
40+
41+
inputs = [torch.randn(shape)]
42+
self.run_test(
43+
Cumsum(),
44+
inputs,
45+
)
46+
47+
@parameterized.expand(
48+
[
49+
((4, 2, 3), 0),
50+
((4, 2, 3), 1),
51+
((1, 2, 3), 2),
52+
((1, 2, 3), -1),
53+
((1, 2, 3), -2),
54+
]
55+
)
56+
def test_cumsum_3D(self, shape, dims):
57+
class Cumsum(nn.Module):
58+
def forward(self, x):
59+
return torch.ops.aten.cumsum.default(x, dims)
60+
61+
inputs = [torch.randn(shape)]
62+
self.run_test(
63+
Cumsum(),
64+
inputs,
65+
)
66+
67+
68+
if __name__ == "__main__":
69+
run_tests()

0 commit comments

Comments
 (0)