@@ -62,26 +62,13 @@ def construct_refit_mapping(
6262 Returns:
6363 Mapping from weight name in TensorRT to actual weight value in np.ndarray
6464 """
65- MODULE_MAP = {
66- "SCALE" : (trt .IScaleLayer , [("scale" , "SCALE" ), ("shift" , "SHIFT" )]),
67- "CONVOLUTION" : (
68- trt .IConvolutionLayer ,
69- [("kernel" , "KERNEL" ), ("bias" , "BIAS" )],
70- ),
71- "DECONVOLUTION" : (
72- trt .IDeconvolutionLayer ,
73- [("kernel" , "KERNEL" ), ("bias" , "BIAS" )],
74- ),
75- "CONSTANT" : (trt .IConstantLayer , [("weights" , "CONSTANT" )]),
76- }
7765
7866 output_dtypes = infer_module_output_dtypes (
7967 module ,
8068 truncate_double = settings .truncate_double ,
8169 )
8270
8371 # Use Interpreter
84- weight_map = {}
8572 interpreter = TRTInterpreter (
8673 module ,
8774 inputs ,
@@ -90,24 +77,8 @@ def construct_refit_mapping(
9077 compilation_settings = settings ,
9178 )
9279 interpreter ._construct_trt_network_def ()
93- net = interpreter .ctx .net
94- for i in range (net .num_layers ):
95- layer = net [i ]
96- layer_type : str = layer .type .name
97- if layer_type in MODULE_MAP :
98- # Cast the parent class to child class to access attributes
99- # For example: ILayer does not have ILayer.kernel/ILayer.bias
100- # So we cast it to IConvolutionLayer and access the attributes
101- layer .__class__ = MODULE_MAP [layer_type ][0 ]
102- for weight_type , weight_name in MODULE_MAP [layer_type ][1 ]:
103- weight = layer .__getattribute__ (weight_type ).copy ()
104- weight_dtype = dtype .try_from (weight .dtype ).to (trt .DataType )
105- weight_map [f"{ layer .name } { weight_name } " ] = (
106- weight ,
107- weight_dtype ,
108- )
10980
110- return weight_map
81+ return interpreter . ctx . mapping
11182
11283
11384@needs_refit
@@ -118,13 +89,12 @@ def construct_refit_mapping_from_weight_name_map(
11889) -> dict [Any , Any ]:
11990 engine_weight_map = {}
12091 for engine_weight_name , (sd_weight_name , np_weight_type ) in weight_name_map .items ():
121- trt_dtype = dtype .try_from (np_weight_type ).to (trt .DataType )
122- torch_dtype = dtype .try_from (np_weight_type ).to (torch .dtype )
123-
12492 if sd_weight_name not in state_dict :
12593 # If weights is not in sd, we can leave it unchanged
12694 continue
12795 else :
96+ trt_dtype = dtype .try_from (np_weight_type ).to (trt .DataType )
97+ torch_dtype = dtype .try_from (np_weight_type ).to (torch .dtype )
12898 engine_weight_map [engine_weight_name ] = state_dict [sd_weight_name ].to (
12999 to_torch_device (settings .device )
130100 )
@@ -208,8 +178,9 @@ def _refit_single_trt_engine_with_gm(
208178 if layer_name not in mapping :
209179 raise AssertionError (f"{ layer_name } is not found in weight mapping" )
210180 # Use Numpy to create weights
211- weight , datatype = mapping [layer_name ]
212- trt_wt_tensor = trt .Weights (datatype , weight .ctypes .data , weight .size )
181+ weight = mapping [layer_name ]
182+ trt_dtype = dtype .try_from (weight .dtype ).to (trt .DataType )
183+ trt_wt_tensor = trt .Weights (trt_dtype , weight .ctypes .data , weight .size )
213184 refitter .set_named_weights (layer_name , trt_wt_tensor , trt_wt_location )
214185 refitted .add (layer_name )
215186
0 commit comments