@@ -1,8 +1,8 @@
 import math
 from typing import Optional
 
+import numpy as np
 import tensorrt as trt
-import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -13,11 +13,9 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
-    Frameworks,
     has_dynamic_shape,
     prepend_ones,
     set_layer_name,
-    unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import Shape, TRTTensor
 
@@ -130,18 +128,16 @@ def cumsum(
     input_shape = input.shape
     dim = get_positive_dim(dim, len(input_shape))
     loop = ctx.net.add_loop()
-    axis = torch.tensor(input_shape[dim], dtype=torch.int32)
+    axis = np.array(input_shape[dim])
     trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
     loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
     iterator = loop.add_iterator(input, dim, reverse=False)
     data = iterator.get_output(0)
     new_dims = tuple(data.shape)
-    zero_tensor = torch.zeros(
-        new_dims, dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
-    )
-    zero_tensor = get_trt_tensor(ctx, zero_tensor, f"{name}_initial_value")
+    zeros = np.zeros(new_dims)
+    zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
 
-    running_sum = loop.add_recurrence(zero_tensor)
+    running_sum = loop.add_recurrence(zero_trttensor)
     set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
     running_sum_tensor = running_sum.get_output(0)
 
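For reference, the TensorRT loop this converter builds computes an inclusive cumulative sum along `dim`: the iterator yields one slice per trip, the recurrence starts from a zero tensor shaped like one slice (now created with NumPy instead of torch), and each trip adds the current slice to the running sum. A minimal NumPy sketch of that accumulation follows; the `naive_cumsum` helper is hypothetical and only illustrates the semantics, it is not part of the patch.

import numpy as np

def naive_cumsum(x: np.ndarray, dim: int) -> np.ndarray:
    # Recurrence initial value: zeros shaped like one slice of `x` along `dim`,
    # matching the iterator output shape used for `new_dims` in the converter.
    running = np.zeros(x.shape[:dim] + x.shape[dim + 1:], dtype=x.dtype)
    slices = []
    for i in range(x.shape[dim]):
        # Each loop trip adds the current slice to the running sum.
        running = running + np.take(x, i, axis=dim)
        slices.append(running)
    # Stacking the per-trip sums back along `dim` corresponds to the loop's
    # concatenated output (that part of the converter is outside this hunk).
    return np.stack(slices, axis=dim)

# Sanity check against NumPy's reference implementation.
x = np.arange(12, dtype=np.float32).reshape(3, 4)
assert np.array_equal(naive_cumsum(x, 1), np.cumsum(x, axis=1))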