@@ -3,20 +3,17 @@
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
     flatten_dims,
     get_axes_for_reduce_op,
-)
-from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
-    set_layer_name,
 )
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
-from . import squeeze
-
 
 def argmax(
     ctx: ConversionContext,
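Note on the import changes: get_positive_dim now comes from the dynamo converter_utils module, only set_layer_name is still taken from the fx converters, and the new impl import gives access to the shared shuffle and squeeze helpers used below in place of hand-built layers. As a rough sketch, the reshape helper wraps the same raw TensorRT calls that the inline code removed below used to make (illustrative only; reshape_sketch is a hypothetical stand-in, not the library's actual implementation, whose internals are not shown in this diff):

    # Illustrative equivalent of impl.shuffle.reshape, reconstructed from
    # the inline shuffle-layer code this diff deletes below.
    def reshape_sketch(ctx, target, source_ir, name, input, new_shape):
        layer = ctx.net.add_shuffle(input)  # TensorRT IShuffleLayer
        layer.reshape_dims = new_shape      # static target shape
        set_layer_name(layer, target, name, source_ir)
        return layer.get_output(0)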
@@ -28,7 +25,7 @@ def argmax(
     keep_dim: bool = False,
 ) -> TRTTensor:
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(ctx, input, trt.float32, name)
+        input = cast_trt_tensor(ctx, input, trt.float32, name, source_ir)
 
     # Three different cases here:
     # 1. dim == None, flatten input tensor first, keep_dim will be ignored and the output rank == input rank
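For reference, the cases the comment enumerates correspond to the following torch.argmax behavior at the PyTorch level, which the converter must reproduce (shapes only):

    import torch

    x = torch.randn(2, 3)
    torch.argmax(x)                       # dim=None: index into the flattened tensor, 0-d result
    torch.argmax(x, dim=1)                # normal case: shape (2,)
    torch.argmax(x, dim=1, keepdim=True)  # shape (2, 1)

    v = torch.randn(5)                    # rank-1 input: the converter below pads
    torch.argmax(v, dim=0)                # it to (5, 1) before TopK; result is 0-d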
@@ -37,40 +34,41 @@ def argmax(
     out = input
 
     if dim is None:
-        shuffle_layer = ctx.net.add_shuffle(input)
-        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
-        set_layer_name(shuffle_layer, target, name + "_flatten")
-        out = shuffle_layer.get_output(0)
+        new_shape = (*flatten_dims(input, 0, -1), 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_flatten", input, new_shape
+        )
     elif len(input.shape) == 1:
-        shuffle_layer = ctx.net.add_shuffle(input)
-        shuffle_layer.reshape_dims = (*input.shape, 1)
-        set_layer_name(shuffle_layer, target, name + "_broadcast")
-        out = shuffle_layer.get_output(0)
+        new_shape = (*input.shape, 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_broadcast", input, new_shape
+        )
 
-    reduce_mask = get_axes_for_reduce_op(0)
-    if dim is not None:
-        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+    # Reduce over the flattened input if the dimension is None, otherwise the specified dimension
+    reduce_mask = get_axes_for_reduce_op(
+        get_positive_dim(dim if dim is not None else 0, len(out.shape))
+    )
 
     topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
-    set_layer_name(topk_layer, target, name)
+    set_layer_name(topk_layer, target, name, source_ir)
 
     out = topk_layer.get_output(1)
 
     if dim is None:
-        out_shuffle_layer = ctx.net.add_shuffle(out)
-        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
-        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
-        out = out_shuffle_layer.get_output(0)
+        new_shape = ((1,) * len(input.shape)) if keep_dim else ()
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_unflatten", out, new_shape
+        )
     elif len(input.shape) == 1:
-        out = squeeze.squeeze(
+        out = impl.squeeze.squeeze(
             ctx,
             target,
-            SourceIR.ATEN,
-            name + "_squeeze",
+            source_ir,
+            f"{name}_squeeze",
             out,
-            1 if keep_dim else [0, 1],
+            1 if keep_dim else (0, 1),
         )
     elif not keep_dim:
-        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
+        out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)
 
     return out
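Two notes on the body above. First, the consolidated reduce_mask expression replaces the old two-step assignment: when dim is None the tensor has already been flattened to (N, 1), so reducing over axis 0 is the correct default, and otherwise the possibly negative dim is normalized by get_positive_dim. get_axes_for_reduce_op turns the axis index into the bitmask that TensorRT reduction layers expect; a minimal sketch of that convention (axes_bitmask is a hypothetical stand-in, assuming the usual one-bit-per-axis encoding):

    # Assumed convention: bit d set <=> reduce over axis d.
    def axes_bitmask(dims):
        mask = 0
        for d in dims:
            mask |= 1 << d
        return mask

    axes_bitmask([0])  # 1: the flattened axis when dim is None
    axes_bitmask([2])  # 4: reduce over axis 2

Second, add_topk with k=1 produces two outputs, the max values and their indices; get_output(1) selects the indices, which is what argmax returns.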