@@ -321,15 +321,16 @@ def cast_int_or_float_to_bool(
321
321
322
322
323
323
def to_trt_weights (
324
- value : Any ,
325
- record_weight : bool = False ,
326
- name : Optional [str ] = None ,
327
- ctx : Optional [ConversionContext ] = None ,
324
+ ctx : ConversionContext ,
325
+ value : torch .Tensor ,
326
+ name : str ,
327
+ layer_type_name : str ,
328
+ weight_type_name : str ,
328
329
target : Optional [Union [Target , str ]] = None ,
329
- layer_type_name : Optional [str ] = None ,
330
- weight_type_name : Optional [str ] = None ,
331
330
source_ir : Optional [SourceIR ] = None ,
332
331
target_quantized_type : Optional [trt .DataType ] = None ,
332
+ dtype : Optional [trt .DataType ] = None ,
333
+ count : Optional [int ] = None ,
333
334
) -> trt .Weights :
334
335
"""
335
336
Convert a PyTorch tensor or NumPy array to TensorRT weights.
@@ -344,57 +345,50 @@ def to_trt_weights(
344
345
- Input tensors are made contiguous before conversion
345
346
- Data type is preserved from the original tensor/array
346
347
"""
347
- if record_weight :
348
- assert name is not None , "name must be provided if record_weight is True"
349
- assert ctx is not None , "ctx must be provided if record_weight is True"
350
- assert target is not None , "target must be provided if record_weight is True"
351
- assert (
352
- layer_type_name is not None
353
- ), "layer_type_name must be provided if record_weight is True"
354
- assert (
355
- weight_type_name is not None
356
- ), "weight_type_name must be provided if record_weight is True"
348
+ if isinstance (value , np .ndarray ):
349
+ raise AssertionError (
350
+ f"to_trt_weights can only be called on torch.Tensor, got an object of type: { type (value )} "
351
+ )
357
352
358
- supported_layer_types = ["CONVOLUTION" , "DECONVOLUTION" ]
359
- supported_weight_types = ["KERNEL" ]
360
- assert (
361
- layer_type_name in supported_layer_types
362
- ), f"Unsupported layer type: { layer_type_name } . Please add the layer type to this function to enable refitting."
353
+ # Weight Recording
354
+ supported_layer_types = ["CONVOLUTION" , "DECONVOLUTION" , "CONSTANT" ]
355
+ supported_weight_types = ["KERNEL" , "BIAS" , "CONSTANT" ]
356
+ assert (
357
+ layer_type_name in supported_layer_types
358
+ ), f"Unsupported layer type: { layer_type_name } . Please add the layer type to this function to enable refitting."
359
+ assert (
360
+ weight_type_name in supported_weight_types
361
+ ), f"Unsupported weight type: { weight_type_name } . Please add the weight type to this function to enable refitting."
362
+
363
+ if weight_type_name == "CONSTANT" and layer_type_name == "CONSTANT" :
364
+ weight_name = f"{ name } CONSTANT"
365
+ ctx .record_weight (weight_name , value )
366
+
367
+ else :
363
368
assert (
364
- weight_type_name in supported_weight_types
365
- ), f"Unsupported weight type: { weight_type_name } . Please add the weight type to this function to enable refitting. "
369
+ target is not None
370
+ ), "target must be provided if the weight type and layer type is not CONSTANT "
366
371
source_ir = source_ir if source_ir is not None else SourceIR .UNKNOWN
367
372
target_name = (
368
373
f"{ source_ir } _ops.{ target } "
369
374
if isinstance (target , str )
370
375
else f"{ source_ir } _ops.{ target .__name__ } "
371
376
)
372
377
373
- name = f"[{ layer_type_name } ]-[{ target_name } ]-[{ name } ] { weight_type_name } "
374
- record_weight_in_ctx (ctx , name , value )
375
-
376
- if isinstance (value , torch .Tensor ):
377
- # Tensor must be contiguous before conversion
378
- value = value .contiguous ()
379
- value_trt_dtype = _enums .dtype ._from (value .dtype ).to (trt .DataType )
380
- return trt .Weights (value_trt_dtype , value .data_ptr (), value .nelement ())
381
- elif isinstance (value , np .ndarray ):
382
- value = np .ascontiguousarray (value )
383
- value_np_dtype = _enums .dtype ._from (value .dtype ).to (np .dtype , use_default = True )
384
- return trt .Weights (value_np_dtype , value .data , value .size )
385
- else :
386
- raise AssertionError (
387
- f"to_trt_weights can only be called on torch.Tensor or np.ndarray, got an object of type: { type (value )} "
388
- )
378
+ weight_name = f"[{ layer_type_name } ]-[{ target_name } ]-[{ name } ] { weight_type_name } "
379
+ ctx .record_weight (weight_name , value )
389
380
381
+ # TRT Weights Creation
390
382
391
- def record_weight_in_ctx (
392
- ctx : ConversionContext ,
393
- name : str ,
394
- value : torch .Tensor ,
395
- ) -> None :
396
- ctx .weight_refit_map [name ] = value
397
- ctx .cpu_weights_reference_holder [name ] = value
383
+ # Tensor must be contiguous before conversion
384
+ value = value .contiguous ()
385
+ if dtype is None :
386
+ dtype = _enums .dtype ._from (value .dtype ).to (trt .DataType )
387
+
388
+ if count is None :
389
+ count = value .nelement ()
390
+
391
+ return trt .Weights (dtype , value .data_ptr (), count )
398
392
399
393
400
394
def create_constant (
@@ -451,24 +445,26 @@ def create_constant(
451
445
"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
452
446
)
453
447
shape [- 1 ] = shape [- 1 ] * 2
454
- weights = trt .Weights (
455
- type = trt .DataType .FP4 ,
456
- ptr = torch_value .data_ptr (),
448
+ weights = to_trt_weights (
449
+ ctx ,
450
+ torch_value ,
451
+ name ,
452
+ "CONSTANT" ,
453
+ "CONSTANT" ,
454
+ dtype = trt .DataType .FP4 ,
457
455
count = torch_value .numel () * 2 ,
458
456
)
459
457
constant = ctx .net .add_constant (
460
458
shape ,
461
459
weights ,
462
460
)
463
461
constant .name = name
464
- record_weight_in_ctx (ctx , name + " FP4_CONSTANT" , torch_value )
465
462
return constant .get_output (0 )
466
463
467
464
# Record the weight in ctx for refit and cpu memory reference
468
- record_weight_in_ctx (ctx , name + " CONSTANT" , torch_value )
469
465
470
466
# Convert the torch.Tensor to a trt.Weights object
471
- trt_weights = to_trt_weights (torch_value , record_weight = False )
467
+ trt_weights = to_trt_weights (ctx , torch_value , name , "CONSTANT" , "CONSTANT" )
472
468
constant = ctx .net .add_constant (
473
469
shape ,
474
470
trt_weights ,
0 commit comments