@@ -93,8 +93,8 @@ def construct_refit_mapping_from_weight_name_map(
9393 # If weights is not in sd, we can leave it unchanged
9494 continue
9595 else :
96- trt_dtype = dtype .try_from (np_weight_type ).to (trt .DataType )
97- torch_dtype = dtype .try_from (np_weight_type ).to (torch .dtype )
96+ trt_dtype = dtype ._from (np_weight_type ).to (trt .DataType )
97+ torch_dtype = dtype ._from (np_weight_type ).to (torch .dtype )
9898 engine_weight_map [engine_weight_name ] = state_dict [sd_weight_name ].to (
9999 to_torch_device (settings .device )
100100 )
@@ -148,8 +148,8 @@ def _refit_single_trt_engine_with_gm(
148148 for constant_name , val in constant_mapping .items ():
149149 np_weight_type = val .dtype
150150 val_tensor = torch .from_numpy (val ).cuda ()
151- trt_dtype = dtype .try_from (np_weight_type ).to (trt .DataType )
152- torch_dtype = dtype .try_from (np_weight_type ).to (torch .dtype )
151+ trt_dtype = dtype ._from (np_weight_type ).to (trt .DataType )
152+ torch_dtype = dtype ._from (np_weight_type ).to (torch .dtype )
153153 constant_mapping_with_type [constant_name ] = (
154154 val_tensor .clone ().reshape (- 1 ).contiguous ().to (torch_dtype ),
155155 trt_dtype ,
@@ -179,7 +179,7 @@ def _refit_single_trt_engine_with_gm(
179179 raise AssertionError (f"{ layer_name } is not found in weight mapping" )
180180 # Use Numpy to create weights
181181 weight = mapping [layer_name ]
182- trt_dtype = dtype .try_from (weight .dtype ).to (trt .DataType )
182+ trt_dtype = dtype ._from (weight .dtype ).to (trt .DataType )
183183 trt_wt_tensor = trt .Weights (trt_dtype , weight .ctypes .data , weight .size )
184184 refitter .set_named_weights (layer_name , trt_wt_tensor , trt_wt_location )
185185 refitted .add (layer_name )
0 commit comments