@@ -46,7 +46,23 @@ def get_ir(target: Target) -> SourceIR:
4646 return SourceIR .UNKNOWN
4747
4848
49+ def one_user_validator (node : Node ) -> bool :
50+ # Validate only one user, which is a getitem node that accesses the first element in the list
51+ return (
52+ len (node .users ) == 1
53+ and list (node .users )[0 ].target == operator .getitem
54+ and list (node .users )[0 ].args [1 ] == 0
55+ )
56+
57+
58+ @dynamo_tensorrt_converter (torch .ops .aten .native_batch_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
59+ @dynamo_tensorrt_converter (torch .ops .aten .batch_norm .default ) # type: ignore[misc]
4960@dynamo_tensorrt_converter (torch .ops .aten .batch_norm ) # type: ignore[misc]
61+ @enforce_tensor_types (
62+ {
63+ 0 : (TRTTensor ,),
64+ }
65+ ) # type: ignore[misc]
5066def aten_ops_batch_norm (
5167 ctx : ConversionContext ,
5268 target : Target ,
@@ -59,14 +75,103 @@ def aten_ops_batch_norm(
5975 target ,
6076 SourceIR .ATEN ,
6177 name ,
62- args [0 ],
63- args [1 ],
64- args [2 ],
65- args [3 ],
66- args [4 ],
67- args [5 ],
68- args [6 ],
69- args [7 ],
78+ input = args [0 ],
79+ weight = args [1 ],
80+ bias = args [2 ],
81+ running_mean = args [3 ],
82+ running_var = args [4 ],
83+ training = args [5 ],
84+ momentum = args [6 ],
85+ eps = args [7 ],
86+ cudnn_enabled = args_bounds_check (args , 8 , True ),
87+ return_mean_rstd = (target == torch .ops .aten .native_batch_norm .default ),
88+ )
89+
90+
91+ @dynamo_tensorrt_converter (torch .ops .aten .native_layer_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
92+ @dynamo_tensorrt_converter (torch .ops .aten .layer_norm .default ) # type: ignore[misc]
93+ @dynamo_tensorrt_converter (torch .ops .aten .layer_norm ) # type: ignore[misc]
94+ @enforce_tensor_types (
95+ {
96+ 0 : (TRTTensor ,),
97+ }
98+ ) # type: ignore[misc]
99+ def aten_ops_layer_norm (
100+ ctx : ConversionContext ,
101+ target : Target ,
102+ args : Tuple [Argument , ...],
103+ kwargs : Dict [str , Argument ],
104+ name : str ,
105+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
106+ return impl .normalization .layer_norm (
107+ ctx ,
108+ target ,
109+ SourceIR .ATEN ,
110+ name ,
111+ input = args [0 ],
112+ normalized_shape = args [1 ],
113+ weight = args_bounds_check (args , 2 ),
114+ bias = args_bounds_check (args , 3 ),
115+ eps = args_bounds_check (args , 4 , 1e-05 ),
116+ cudnn_enable = args_bounds_check (args , 5 , True ),
117+ return_mean_rstd = (target == torch .ops .aten .native_layer_norm .default ),
118+ )
119+
120+
121+ @dynamo_tensorrt_converter (torch .ops .aten .native_group_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
122+ @enforce_tensor_types (
123+ {
124+ 0 : (TRTTensor ,),
125+ }
126+ ) # type: ignore[misc]
127+ def aten_ops_native_group_norm (
128+ ctx : ConversionContext ,
129+ target : Target ,
130+ args : Tuple [Argument , ...],
131+ kwargs : Dict [str , Argument ],
132+ name : str ,
133+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
134+ return impl .normalization .native_group_norm (
135+ ctx ,
136+ target ,
137+ SourceIR .ATEN ,
138+ name ,
139+ input = args [0 ],
140+ weight = args [1 ],
141+ bias = args [2 ],
142+ N = args [3 ],
143+ C = args [4 ],
144+ HxW = args [5 ],
145+ group = args [6 ],
146+ eps = args [7 ],
147+ )
148+
149+
150+ @dynamo_tensorrt_converter (torch .ops .aten .group_norm .default ) # type: ignore[misc]
151+ @dynamo_tensorrt_converter (torch .ops .aten .group_norm ) # type: ignore[misc]
152+ @enforce_tensor_types (
153+ {
154+ 0 : (TRTTensor ,),
155+ }
156+ ) # type: ignore[misc]
157+ def aten_ops_group_norm (
158+ ctx : ConversionContext ,
159+ target : Target ,
160+ args : Tuple [Argument , ...],
161+ kwargs : Dict [str , Argument ],
162+ name : str ,
163+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
164+ return impl .normalization .group_norm (
165+ ctx ,
166+ target ,
167+ SourceIR .ATEN ,
168+ name ,
169+ input = args [0 ],
170+ num_groups = args [1 ],
171+ weight = args_bounds_check (args , 2 , None ),
172+ bias = args_bounds_check (args , 3 , None ),
173+ eps = args_bounds_check (args , 4 , 1e-05 ),
174+ cudnn_enabled = args_bounds_check (args , 5 , True ),
70175 )
71176
72177
@@ -328,27 +433,6 @@ def aten_ops_matmul(
328433 )
329434
330435
331- @dynamo_tensorrt_converter (torch .ops .aten .layer_norm .default ) # type: ignore[misc]
332- def aten_ops_layernorm (
333- ctx : ConversionContext ,
334- target : Target ,
335- args : Tuple [Argument , ...],
336- kwargs : Dict [str , Argument ],
337- name : str ,
338- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
339- return impl .normalization .layer_norm (
340- ctx ,
341- target ,
342- SourceIR .ATEN ,
343- name ,
344- args [0 ],
345- args [1 ],
346- args [2 ],
347- args [3 ],
348- args [4 ],
349- )
350-
351-
352436@dynamo_tensorrt_converter (torch .ops .aten .rsqrt .default ) # type: ignore[misc]
353437def aten_ops_rsqrt (
354438 ctx : ConversionContext ,
@@ -763,15 +847,6 @@ def aten_ops_prod(
763847 )
764848
765849
766- def one_user_validator (node : Node ) -> bool :
767- # Validate only one user, which is a getitem node that accesses the first element in the list
768- return (
769- len (node .users ) == 1
770- and list (node .users )[0 ].target == operator .getitem
771- and list (node .users )[0 ].args [1 ] == 0
772- )
773-
774-
775850@dynamo_tensorrt_converter (torch .ops .aten .max .default ) # type: ignore[misc]
776851@dynamo_tensorrt_converter (torch .ops .aten .max .dim , capability_validator = one_user_validator ) # type: ignore[misc]
777852def aten_ops_max (
0 commit comments