Commit 60c576d

fix: Bugfixes and review comments

- Added regression test

Parent: 756a41a

File tree: 2 files changed (+26 −27)


py/torch_tensorrt/dynamo/conversion/impl/argmax.py

Lines changed: 25 additions & 27 deletions
```diff
@@ -3,20 +3,17 @@
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
     flatten_dims,
     get_axes_for_reduce_op,
-)
-from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
-    set_layer_name,
 )
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
-from . import squeeze
-
 
 def argmax(
     ctx: ConversionContext,
@@ -28,7 +25,7 @@ def argmax(
     keep_dim: bool = False,
 ) -> TRTTensor:
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(ctx, input, trt.float32, name)
+        input = cast_trt_tensor(ctx, input, trt.float32, name, source_ir)
 
     # Three different cases here:
     # 1. dim == None, flatten input tensor first, keep_dim will be ignore and the output rank == input rank
@@ -37,40 +34,41 @@ def argmax(
     out = input
 
     if dim is None:
-        shuffle_layer = ctx.net.add_shuffle(input)
-        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
-        set_layer_name(shuffle_layer, target, name + "_flatten")
-        out = shuffle_layer.get_output(0)
+        new_shape = (*flatten_dims(input, 0, -1), 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_flatten", input, new_shape
+        )
     elif len(input.shape) == 1:
-        shuffle_layer = ctx.net.add_shuffle(input)
-        shuffle_layer.reshape_dims = (*input.shape, 1)
-        set_layer_name(shuffle_layer, target, name + "_broadcast")
-        out = shuffle_layer.get_output(0)
+        new_shape = (*input.shape, 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_broadcast", input, new_shape
+        )
 
-    reduce_mask = get_axes_for_reduce_op(0)
-    if dim is not None:
-        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+    # Reduce over the flattened input if the dimension is None, otherwise the specified dimension
+    reduce_mask = get_axes_for_reduce_op(
+        get_positive_dim(dim if dim is not None else 0, len(out.shape))
+    )
 
     topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
-    set_layer_name(topk_layer, target, name)
+    set_layer_name(topk_layer, target, name, source_ir)
 
     out = topk_layer.get_output(1)
 
     if dim is None:
-        out_shuffle_layer = ctx.net.add_shuffle(out)
-        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
-        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
-        out = out_shuffle_layer.get_output(0)
+        new_shape = ((1,) * len(input.shape)) if keep_dim else ()
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_unflatten", out, new_shape
+        )
     elif len(input.shape) == 1:
-        out = squeeze.squeeze(
+        out = impl.squeeze.squeeze(
             ctx,
             target,
-            SourceIR.ATEN,
-            name + "_squeeze",
+            source_ir,
+            f"{name}_squeeze",
             out,
-            1 if keep_dim else [0, 1],
+            1 if keep_dim else (0, 1),
         )
     elif not keep_dim:
-        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
+        out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)
 
     return out
```
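For reference, here is a minimal sketch of the consolidated reduce-mask computation introduced above: when `dim` is `None` the flattened tensor is reduced over axis 0, and negative dims are normalized before the TopK axes bitmask is built. The helper bodies below are hypothetical stand-ins written for this illustration, not torch_tensorrt's actual implementations.

```python
# Hypothetical stand-ins for torch_tensorrt's converter utilities,
# included only to illustrate the axis normalization; not library code.

def get_positive_dim(dim: int, rank: int) -> int:
    # Torch-style negative indexing: -1 on a rank-3 tensor -> 2.
    return dim % rank

def get_axes_for_reduce_op(axis: int) -> int:
    # TensorRT reduce/TopK layers take a bitmask of axes.
    return 1 << axis

rank = 3
for dim in (None, 0, -1):
    axis = get_positive_dim(dim if dim is not None else 0, rank)
    print(dim, "->", bin(get_axes_for_reduce_op(axis)))
# None -> 0b1
# 0    -> 0b1
# -1   -> 0b100
```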

tests/py/dynamo/conversion/test_argmax_aten.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@ class TestArgmaxConverter(DispatchTestCase):
             ("dim_1_keep_dim_false", (3, 3), 1, False),
             ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
             ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 2, 3), -1, True),
         ]
     )
     def test_argmax(self, _, input_shape, dim, keep_dim):
```
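The new parameterized case exercises a negative dimension end to end. As a plain-PyTorch illustration of the reference behavior the converter output is compared against (the shapes come from the new test case; this snippet itself is not part of the test suite):

```python
import torch

# The regression case: argmax along dim=-1 with keep_dim=True on a
# (1, 2, 3) input; the reduced dimension is kept with size 1.
x = torch.randn(1, 2, 3)
out = torch.argmax(x, dim=-1, keepdim=True)
print(out.shape)  # torch.Size([1, 2, 1])
```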
