 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
     flatten_dims,
     get_axes_for_reduce_op,
-)
-from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
-    set_layer_name,
 )
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
-from . import squeeze
-
 
 def argmax(
     ctx: ConversionContext,
@@ -28,7 +25,7 @@ def argmax(
     keep_dim: bool = False,
 ) -> TRTTensor:
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(ctx, input, trt.float32, name)
+        input = cast_trt_tensor(ctx, input, trt.float32, name, source_ir)
 
     # Three different cases here:
     # 1. dim == None, flatten input tensor first, keep_dim will be ignored and the output rank == input rank
@@ -37,40 +34,41 @@ def argmax(
     out = input
 
     if dim is None:
-        shuffle_layer = ctx.net.add_shuffle(input)
-        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
-        set_layer_name(shuffle_layer, target, name + "_flatten")
-        out = shuffle_layer.get_output(0)
+        new_shape = (*flatten_dims(input, 0, -1), 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_flatten", input, new_shape
+        )
     elif len(input.shape) == 1:
-        shuffle_layer = ctx.net.add_shuffle(input)
-        shuffle_layer.reshape_dims = (*input.shape, 1)
-        set_layer_name(shuffle_layer, target, name + "_broadcast")
-        out = shuffle_layer.get_output(0)
+        new_shape = (*input.shape, 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_broadcast", input, new_shape
+        )
 
-    reduce_mask = get_axes_for_reduce_op(0)
-    if dim is not None:
-        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+    # Reduce over the flattened input if the dimension is None, otherwise the specified dimension
+    reduce_mask = get_axes_for_reduce_op(
+        get_positive_dim(dim if dim is not None else 0, len(out.shape))
+    )
 
     topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
-    set_layer_name(topk_layer, target, name)
+    set_layer_name(topk_layer, target, name, source_ir)
 
     out = topk_layer.get_output(1)
 
     if dim is None:
-        out_shuffle_layer = ctx.net.add_shuffle(out)
-        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
-        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
-        out = out_shuffle_layer.get_output(0)
+        new_shape = ((1,) * len(input.shape)) if keep_dim else ()
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_unflatten", out, new_shape
+        )
     elif len(input.shape) == 1:
-        out = squeeze.squeeze(
+        out = impl.squeeze.squeeze(
             ctx,
             target,
-            SourceIR.ATEN,
-            name + "_squeeze",
+            source_ir,
+            f"{name}_squeeze",
             out,
-            1 if keep_dim else [0, 1],
+            1 if keep_dim else (0, 1),
         )
     elif not keep_dim:
-        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
+        out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)
 
     return out
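
For reference, the three shape-handling cases in the converter above can be mirrored in a short pure-PyTorch sketch (illustration only, not part of this change; `argmax_sketch` is a hypothetical name, and `torch.topk(...).indices` stands in for output 1 of TensorRT's TopK layer):

import torch

def argmax_sketch(x: torch.Tensor, dim=None, keep_dim: bool = False) -> torch.Tensor:
    out = x
    if dim is None:
        # Case 1: flatten to rank 2 so a single TopK axis covers every element.
        out = out.reshape(-1, 1)
    elif out.dim() == 1:
        # Case 2: TopK cannot reduce a rank-1 tensor, so broadcast it to rank 2.
        out = out.reshape(*out.shape, 1)

    # Reduce over axis 0 for the flattened case, otherwise over the given dim.
    axis = dim if dim is not None else 0
    out = torch.topk(out, k=1, dim=axis).indices

    # Undo the shape manipulation from the first step.
    if dim is None:
        out = out.reshape((1,) * x.dim() if keep_dim else ())
    elif x.dim() == 1:
        out = out.squeeze(1) if keep_dim else out.squeeze((0, 1))
    elif not keep_dim:
        out = out.squeeze(axis)
    return out

# Sanity check against torch.argmax (ties aside, the index selection agrees):
t = torch.randn(2, 3)
assert torch.equal(argmax_sketch(t, dim=1), torch.argmax(t, dim=1))
assert torch.equal(argmax_sketch(t), torch.argmax(t))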