66import torch
77from torch .fx .node import Target
88from torch_tensorrt .dynamo ._SourceIR import SourceIR
9- from torch_tensorrt .dynamo .conversion .impl .elementwise .base import (
10- convert_binary_elementwise ,
11- )
12- from torch_tensorrt .dynamo .conversion .impl .unary .base import convert_unary
9+ from torch_tensorrt .dynamo .conversion import impl
10+ from torch_tensorrt .dynamo .conversion .converter_utils import get_axes_for_reduce_op
1311from torch_tensorrt .fx .converters .converter_utils import (
1412 get_positive_dim ,
1513 get_trt_plugin ,
14+ get_trt_tensor ,
1615 has_dynamic_shape ,
1716 set_layer_name ,
1817 to_numpy ,
@@ -29,10 +28,10 @@ def batch_norm(
2928 source_ir : Optional [SourceIR ],
3029 name : str ,
3130 input : TRTTensor ,
32- weight : Optional [Union [TRTTensor , torch .Tensor , np .ndarray ]],
33- bias : Optional [Union [TRTTensor , torch .Tensor , np .ndarray ]],
34- running_mean : Optional [Union [TRTTensor , torch .Tensor , np .ndarray ]],
35- running_var : Optional [Union [TRTTensor , torch .Tensor , np .ndarray ]],
31+ weight : Optional [Union [torch .Tensor , np .ndarray ]],
32+ bias : Optional [Union [torch .Tensor , np .ndarray ]],
33+ running_mean : Optional [Union [torch .Tensor , np .ndarray ]],
34+ running_var : Optional [Union [torch .Tensor , np .ndarray ]],
3635 training : bool ,
3736 momentum : float ,
3837 eps : float ,
@@ -103,8 +102,8 @@ def layer_norm(
103102 name : str ,
104103 input : TRTTensor ,
105104 normalized_shape : List [int ],
106- weight : Optional [Union [TRTTensor , torch .Tensor , np .ndarray ]],
107- bias : Optional [Union [TRTTensor , torch .Tensor , np .ndarray ]],
105+ weight : Optional [Union [torch .Tensor , np .ndarray ]],
106+ bias : Optional [Union [torch .Tensor , np .ndarray ]],
108107 eps : float ,
109108 cudnn_enable : bool ,
110109) -> Union [TRTTensor , Sequence [TRTTensor ]]:
@@ -170,8 +169,8 @@ def layer_norm_no_plugin(
170169 name : str ,
171170 input : TRTTensor ,
172171 normalized_shape : List [int ],
173- weight : Optional [Union [TRTTensor , torch .Tensor , np .ndarray ]],
174- bias : Optional [Union [TRTTensor , torch .Tensor , np .ndarray ]],
172+ weight : Optional [Union [torch .Tensor , np .ndarray ]],
173+ bias : Optional [Union [torch .Tensor , np .ndarray ]],
175174 eps : float ,
176175) -> Union [TRTTensor , Sequence [TRTTensor ]]:
177176 if not isinstance (input , TRTTensor ):
@@ -188,79 +187,77 @@ def layer_norm_no_plugin(
188187
189188 shape = weight .shape
190189 broadcasted_shape = (1 ,) * (len (input .shape ) - len (shape )) + shape
191- gamma = to_numpy (weight .reshape (* shape ) )
192- beta = to_numpy (bias .reshape (* shape ) )
190+ gamma = to_numpy (weight ) .reshape (shape )
191+ beta = to_numpy (bias ) .reshape (shape )
193192
194- axes = 0
195- for d in range (len (shape )):
196- axes |= 1 << (len (input .shape ) - d - 1 )
193+ dims = list (range (len (input .shape ) - len (shape ), len (input .shape )))
194+ axes = get_axes_for_reduce_op (dims )
197195
198196 # E[x]
199197 mean_expected_layer = network .add_reduce (
200198 input , trt .ReduceOperation .AVG , axes , keep_dims = True
201199 )
202200 set_layer_name (mean_expected_layer , target , f"{ name } _mean_expected" , source_ir )
203201
204- # X- E[x]
205- sub_trt = convert_binary_elementwise (
202+ # X - E[x]
203+ sub_trt = impl . elementwise . sub (
206204 network ,
207205 target ,
208206 source_ir ,
209207 f"{ name } _sub" ,
210- trt .ElementWiseOperation .SUB ,
211208 input ,
212209 mean_expected_layer .get_output (0 ),
213210 )
214- # Variance = mean(pow(x_sub_mean,2))
211+
212+ # variance = mean(pow(x_sub_mean, 2))
215213 pow_tensor = network .add_constant (
216214 (1 ,) * len (input .shape ),
217215 trt .Weights (np .ascontiguousarray ([2.0 ], dtype = np .float32 )),
218216 )
219217 pow_tensor .name = f"{ name } _power"
220- pow_var = convert_binary_elementwise (
218+ pow_var = impl . elementwise . pow (
221219 network ,
222220 target ,
223221 source_ir ,
224222 f"{ name } _pow_var" ,
225- trt .ElementWiseOperation .POW ,
226223 sub_trt ,
227224 pow_tensor .get_output (0 ),
228225 )
229226 mean_trt_layer = network .add_reduce (
230227 pow_var , trt .ReduceOperation .AVG , axes , keep_dims = True
231228 )
232229 set_layer_name (mean_trt_layer , target , f"{ name } _mean" , source_ir )
233- # Variance + eps
230+
231+ # var + eps
234232 eps_tensor = network .add_constant (
235233 (1 ,) * len (input .shape ),
236234 trt .Weights (np .ascontiguousarray ([eps ], dtype = np .float32 )),
237235 )
238236 eps_tensor .name = f"{ name } _eps"
239- add_trt = convert_binary_elementwise (
237+
238+ # sqrt((var + eps))
239+ add_trt = impl .elementwise .add (
240240 network ,
241241 target ,
242242 source_ir ,
243243 f"{ name } _add" ,
244- trt .ElementWiseOperation .SUM ,
245244 mean_trt_layer .get_output (0 ),
246245 eps_tensor .get_output (0 ),
247246 )
248- # SQRT((Var + eps))
249- sqrt_trt = convert_unary (
247+ sqrt_trt = impl .unary .sqrt (
250248 network ,
251249 target ,
252250 source_ir ,
253251 f"{ name } _sqrt" ,
254- trt .UnaryOperation .SQRT ,
255252 add_trt ,
256253 )
257- # (x - E[x]) / sqrt((var + eps))
258- div_trt = convert_binary_elementwise (
254+
255+ # (X - E[X]) / sqrt((var + eps))
256+ div_trt = impl .elementwise .div (
259257 network ,
260258 target ,
261259 source_ir ,
262260 f"{ name } _div_trt" ,
263- trt .ElementWiseOperation .DIV ,
264261 sub_trt ,
265262 sqrt_trt ,
266263 )
@@ -270,32 +267,113 @@ def layer_norm_no_plugin(
270267 gamma .shape , trt .Weights (np .ascontiguousarray (gamma ))
271268 )
272269 gamma_tensor .name = f"{ name } _gamma"
270+
273271 assert beta is not None
274272 beta_tensor = network .add_constant (
275273 gamma .shape , trt .Weights (np .ascontiguousarray (beta ))
276274 )
277275 beta_tensor .name = f"{ name } _beta"
276+
278277 # y * gamma + beta
279- scale_layer = convert_binary_elementwise (
278+ scaled_y = impl . elementwise . mul (
280279 network ,
281280 target ,
282281 source_ir ,
283282 f"{ name } _scale" ,
284- trt .ElementWiseOperation .PROD ,
285283 div_trt ,
286284 gamma_tensor .get_output (0 ),
287285 )
288- return convert_binary_elementwise (
286+ return impl . elementwise . add (
289287 network ,
290288 target ,
291289 source_ir ,
292290 name ,
293- trt .ElementWiseOperation .SUM ,
294- scale_layer ,
291+ scaled_y ,
295292 beta_tensor .get_output (0 ),
296293 )
297294
298295
296+ def native_group_norm (
297+ network : TRTNetwork ,
298+ target : Target ,
299+ source_ir : Optional [SourceIR ],
300+ name : str ,
301+ input : TRTTensor ,
302+ weight : Optional [Union [torch .Tensor , np .ndarray ]],
303+ bias : Optional [Union [torch .Tensor , np .ndarray ]],
304+ N : int ,
305+ C : int ,
306+ HxW : int ,
307+ group : int ,
308+ eps : float ,
309+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
310+ return group_norm (
311+ network ,
312+ target ,
313+ source_ir ,
314+ name ,
315+ input ,
316+ group ,
317+ weight ,
318+ bias ,
319+ eps ,
320+ cudnn_enabled = True ,
321+ )
322+
323+
324+ def group_norm (
325+ network : TRTNetwork ,
326+ target : Target ,
327+ source_ir : Optional [SourceIR ],
328+ name : str ,
329+ input : TRTTensor ,
330+ num_groups : int ,
331+ weight : Optional [Union [torch .Tensor , np .ndarray ]],
332+ bias : Optional [Union [torch .Tensor , np .ndarray ]],
333+ eps : float ,
334+ cudnn_enabled : bool ,
335+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
336+ if not isinstance (input , trt .tensorrt .ITensor ):
337+ raise RuntimeError (
338+ f"LayerNorm received input { input } that is not part "
339+ "of the TensorRT region!"
340+ )
341+
342+ if weight is None :
343+ weight = to_numpy (1.0 )
344+
345+ if bias is None :
346+ bias = to_numpy (0.0 )
347+
348+ scale = get_trt_tensor (network , weight , "scale" )
349+ bias = get_trt_tensor (network , bias , "bias" )
350+
351+ eps_field = trt .PluginField (
352+ "eps" , np .array (eps , dtype = np .float32 ), trt .PluginFieldType .FLOAT32
353+ )
354+ num_groups_filed = trt .PluginField (
355+ "num_groups" , np .array (num_groups ), trt .PluginFieldType .INT32
356+ )
357+
358+ field_collection = trt .PluginFieldCollection ([eps_field , num_groups_filed ])
359+
360+ try :
361+ # Here's the schema of the plugin:
362+ # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
363+ plugin = get_trt_plugin ("GroupNormalizationPlugin" , field_collection , "1" )
364+ except AssertionError :
365+ _LOGGER .error (
366+ "Unable to find group norm plugin, fall back to TensorRT implementation."
367+ )
368+
369+ layer = network .add_plugin_v2 ([input , scale , bias ], plugin )
370+ set_layer_name (layer , target , f"{ name } _GroupNormalizationPlugin" , source_ir )
371+
372+ # PyTorch requires three return values: (out, mean, rstd)
373+ dummy_tensor = torch .tensor (0 )
374+ return layer .get_output (0 ), dummy_tensor , dummy_tensor
375+
376+
299377def softmax (
300378 network : TRTNetwork ,
301379 target : Target ,
0 commit comments