     TRTPluginFieldCollection,
     TRTTensor,
 )
-from ..utils import torch_dtype_from_trt
+from ..utils import unified_dtype_converter, Frameworks


 class SourceIR(Enum):
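
Note on the new import: `unified_dtype_converter` and `Frameworks` replace the one-directional `torch_dtype_from_trt` (and the `trt_dtype_to_torch_dtype` table removed at the end of this diff) with a single lookup parameterized by a target framework. The real implementation lives in `..utils` and is not part of this diff; the sketch below is only an assumed illustration of its interface, inferred from the call sites changed here.

    # Assumed interface only -- the actual helper in ..utils may differ in detail.
    from enum import Enum

    import numpy as np
    import tensorrt as trt
    import torch

    class Frameworks(Enum):
        NUMPY = "numpy"
        TORCH = "torch"
        TRT = "trt"

    # Rows of equivalent dtypes: (numpy, torch, tensorrt).
    _DTYPE_ROWS = [
        (np.dtype(np.bool_), torch.bool, trt.bool),
        (np.dtype(np.int32), torch.int32, trt.int32),
        (np.dtype(np.float16), torch.float16, trt.float16),
        (np.dtype(np.float32), torch.float32, trt.float32),
    ]
    _COLUMN = {Frameworks.NUMPY: 0, Frameworks.TORCH: 1, Frameworks.TRT: 2}

    def unified_dtype_converter(dtype, to):
        """Map a numpy / torch / TensorRT dtype to its equivalent in framework `to`."""
        if isinstance(dtype, torch.dtype):
            index = 1
        elif isinstance(dtype, trt.DataType):
            index = 2
        else:
            index, dtype = 0, np.dtype(dtype)  # normalizes np scalar types and np.dtype objects
        for row in _DTYPE_ROWS:
            if row[index] == dtype:
                return row[_COLUMN[to]]
        raise TypeError(f"Unsupported dtype: {dtype}")

    unified_dtype_converter(trt.float16, Frameworks.NUMPY)  # -> dtype('float16')
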
@@ -151,38 +151,49 @@ def extend_mod_attr_to_tuple(mod: torch.nn.Module, name: str, size: int):
     return extend_attr_to_tuple(val, size)


-def to_numpy(value: Optional[Union[torch.Tensor, int, float]]) -> Optional[np.ndarray]:
+def to_numpy(
+    value: Optional[Union[torch.Tensor, np.ndarray, int, float]],
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
+) -> Optional[np.ndarray]:
     """
     Convert a PyTorch Tensor to a Numpy Array. If the tensor is
     quantized it will be dequantized first.

     Args:
-        value (Optional[Union[torch.Tensor, int, float]]): A PyTorch tensor, int, or float
+        value (Optional[Union[torch.Tensor, np.ndarray, int, float]]):
+            A PyTorch tensor, Numpy array, int, or float

     Returns:
         A Numpy array.
     """
+    output = None

-    if value is None:
-        return value
+    if value is None or isinstance(value, np.ndarray):
+        output = value

     elif isinstance(value, torch.Tensor):
         if value.is_quantized:
             value = value.dequantize()

-        return value.cpu().detach().contiguous().numpy()
+        output = value.cpu().detach().contiguous().numpy()

     elif isinstance(value, int):
-        return np.array([value], dtype=np.int32)
+        output = np.array([value], dtype=np.int32)

     elif isinstance(value, float):
-        return np.array([value], dtype=np.float32)
+        output = np.array([value], dtype=np.float32)

     else:
         raise AssertionError(
-            f"to_numpy can only be called on None, int, float, or torch.Tensor, got: {value}"
+            f"to_numpy can only be called on None, int, float, np.ndarray, or torch.Tensor, got: {value}"
         )

+    return (
+        output
+        if dtype is None
+        else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY))
+    )
+

 def has_dynamic_shape(shape: Shape) -> bool:
     """
@@ -234,35 +245,35 @@ def get_axes_for_reduce_op(

 def create_constant(
     network: TRTNetwork,
-    value: Union[int, float, torch.Tensor],
+    value: Union[int, float, np.ndarray, torch.Tensor],
     name: str,
-    dtype: Optional[torch.dtype],
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]],
 ) -> TRTTensor:
     """
     Add a TensorRT constant layer whose value is `value` to `network`.

     Args:
         network (TRTNetwork): A TensorRT network to which we want to add
             a constant layer.
-        value (Union[int, float, torch.Tensor]): A literal value or a PyTorch tensor
-            that will be used as value of the added TensorRT Constant layer.
+        value (Union[int, float, np.ndarray, torch.Tensor]): A literal value, Numpy array,
+            or a PyTorch tensor that will be used as value of the added TensorRT Constant layer.
         name (str): Name of the added TensorRT Constant layer.
-        dtype (Optional[torch.dtype]): If a dtype is given, we will convert the type
-            of the given `value` to this dtype.
+        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
+            If a dtype is given, we will convert the type of the given `value` to this dtype.

     Returns:
         A TensorRT ITensor that represents the given value.
     """
-
-    if dtype:
-        value = value.to(dtype)
-    constant = network.add_constant(value.shape, to_numpy(value))
+    constant = network.add_constant(value.shape, to_numpy(value, dtype))
     constant.name = name
     return constant.get_output(0)


 def get_trt_tensor(
-    network: TRTNetwork, input_val: Any, name: str, dtype: Optional[torch.dtype] = None
+    network: TRTNetwork,
+    input_val: Any,
+    name: str,
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
 ) -> TRTTensor:
     """
     Given a value of random type, we try to convert it to a TensorRT ITensor.
@@ -274,33 +285,36 @@ def get_trt_tensor(
         input_val (Any): A value that we want to convert to a TensorRT ITensor.
         name (str): The name of the created TensorRT Constant layer if there's
             one.
-        dtype (Optional[torch.dtype]): If dtype is provided, the given value
-            will be converted to this dtype.
+        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
+            If dtype is provided, the given value will be converted to this dtype.

     Returns:
         A TensorRT ITensor that represents the given value.
     """
     # TRT can not add constant for bool type. We do a work around to 1) cast it to int and 2) cast to bool later.
     # This is useful for logical operations which require input to be bool type
-    if isinstance(input_val, np.ndarray):
-        input_val = torch.from_numpy(input_val)
     if isinstance(input_val, bool):
         input_val = int(input_val)
-    if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.bool:
-        input_val = input_val.to(torch.int32)
-    if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.int64:
+
+    if isinstance(input_val, torch.Tensor) and (
+        input_val.dtype == torch.bool or input_val.dtype == torch.int64
+    ):
         input_val = input_val.to(torch.int32)
+    elif isinstance(input_val, np.ndarray) and (
+        input_val.dtype == np.bool_ or input_val.dtype == np.int64
+    ):
+        input_val = input_val.astype(np.int32)

-    if isinstance(input_val, (torch.Tensor, int, float)):
+    if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
         return create_constant(network, input_val, name, dtype)
-    elif not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Received input {input_val} of name {name} that "
-            "is not part of the TensorRT region!"
-        )
-    else:
+    elif isinstance(input_val, TRTTensor):
         return input_val

+    raise RuntimeError(
+        f"Received input {input_val} of name {name} that "
+        "is not part of the TensorRT region!"
+    )
+

 def prepend_ones(
     network: TRTNetwork,
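
With the ndarray branch handled natively, numpy constants no longer take a detour through `torch.from_numpy`. A hypothetical call site (the `network` handle and layer names are illustrative, not from this diff):

    weights = np.random.randn(16, 8).astype(np.float32)
    w_trt = get_trt_tensor(network, weights, "fc1_weight")                      # constant from ndarray
    h_trt = get_trt_tensor(network, torch.ones(8), "ones", dtype=trt.float16)   # tensor, cast to float16
    i_trt = get_trt_tensor(network, np.arange(8, dtype=np.int64), "idx")        # int64 downcast to int32
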
@@ -482,10 +496,10 @@ def add_binary_elementwise_layer(
     is_rhs_trt_tensor = False

     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
+        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
+        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
         is_rhs_trt_tensor = True

     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -510,9 +524,13 @@ def add_binary_elementwise_layer(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array(
+            [rhs_val], dtype=unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
+        )
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
+        lhs_val = np.array(
+            [lhs_val], dtype=unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
+        )

     # When lhs is scalar, and rhs has shape [1,], then currently the assert
     # will fail because lhs shape has fewer dimensions than rhs shape. This
@@ -552,14 +570,19 @@ def add_binary_elementwise_layer(
     return output


-def squeeze_left(const: torch.Tensor):
+def squeeze_left(const: Union[torch.Tensor, np.ndarray]):
     """
     Squeeze the size-1 dimensions on the left side of the shape tuple.
     PyTorch's `squeeze()` doesn't support passing multiple `dim`s at once, so
     we do it iteratively.
     """
     while len(const.shape) > 0 and const.shape[0] == 1:
-        const = const.squeeze(dim=0)
+        if isinstance(const, torch.Tensor):
+            const = const.squeeze(dim=0)
+        elif isinstance(const, np.ndarray):
+            const = const.squeeze(axis=0)
+        else:
+            raise AssertionError(f"Expected torch Tensor or Numpy array, got: {const}")
     return const


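Both array types are now accepted by `squeeze_left`; for example (illustrative only):

    squeeze_left(np.ones((1, 1, 3, 4))).shape   # -> (3, 4)
    squeeze_left(torch.ones(1, 2, 3)).shape     # -> torch.Size([2, 3])
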
@@ -786,7 +809,10 @@ def trunc_div(
     input = get_trt_tensor(network, input, f"{name}_input")
     if not isinstance(other, trt.tensorrt.ITensor):
         other = get_trt_tensor(
-            network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
+            network,
+            other,
+            f"{name}_other",
+            dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
         )

     abs_input_output = add_unary_layer(
@@ -875,13 +901,3 @@ def type_cast(
     layer_i.set_output_type(0, cast_type)
     set_layer_name(layer_i, target, f"{name}_dtype_change")
     return layer_i.get_output(0)
-
-
-def trt_dtype_to_torch_dtype(trt_dtype):
-    table = {
-        trt.bool: torch.bool,
-        trt.int32: torch.int32,
-        trt.float16: torch.float16,
-        trt.float32: torch.float32,
-    }
-    return table[trt_dtype]
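
The deleted `trt_dtype_to_torch_dtype` table is presumably subsumed by the unified converter introduced above; its former lookups would now read:

    unified_dtype_converter(trt.float16, Frameworks.TORCH)  # -> torch.float16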