@@ -400,7 +400,7 @@ def _construct_trt_network_def(self) -> None:
400
400
@staticmethod
401
401
def find_weight (
402
402
weight_name : str ,
403
- np_map : dict [str , Any ],
403
+ weight_refit_map : dict [str , Any ],
404
404
state_dict : dict [str , Any ],
405
405
device : torch .device ,
406
406
) -> str :
@@ -413,7 +413,7 @@ def find_weight(
413
413
state_dict: state of the graph module
414
414
"""
415
415
with unset_fake_temporarily ():
416
- network_weight = torch . from_numpy ( np_map [weight_name ]) .to (device )
416
+ network_weight = weight_refit_map [weight_name ].to (device )
417
417
for sd_w_name , sd_weight in state_dict .items ():
418
418
if TRTInterpreter .check_weight_equal (sd_weight , network_weight , device ):
419
419
del state_dict [sd_w_name ]
@@ -427,8 +427,8 @@ def check_weight_equal(
427
427
device : torch .device ,
428
428
) -> Any :
429
429
with unset_fake_temporarily ():
430
- if not isinstance ( network_weight , torch . Tensor ) :
431
- network_weight = torch . from_numpy ( network_weight ) .to (device )
430
+ if network_weight . device != device :
431
+ network_weight = network_weight .to (device )
432
432
try :
433
433
return sd_weight .shape == network_weight .shape and torch .all (
434
434
torch .abs (sd_weight - network_weight ) < 0.01
@@ -494,11 +494,10 @@ def _save_weight_mapping(self) -> None:
494
494
_LOGGER .info ("Building weight name mapping..." )
495
495
# Stage 1: Name mapping
496
496
torch_device = to_torch_device (self .compilation_settings .device )
497
- self .module .to (torch_device )
498
- sd = self .module .state_dict ()
497
+ sd = {k : v .to (torch_device ) for k , v in self .module .state_dict ().items ()}
499
498
weight_name_map : dict [str , Any ] = {}
500
- np_map = self .ctx .weight_refit_map
501
- constant_mapping = {k : v for k , v in np_map .items () if v .size == 1 }
499
+ weight_refit_map = self .ctx .weight_refit_map
500
+ constant_mapping = {k : v for k , v in weight_refit_map .items () if v .size == 1 }
502
501
net = self .ctx .net
503
502
for i in range (net .num_layers ):
504
503
layer = net [i ]
@@ -540,7 +539,7 @@ def _save_weight_mapping(self) -> None:
540
539
else :
541
540
sd_weight_name = f"{ sd_weight_name } .{ torch_attr } "
542
541
543
- if engine_weight_name in np_map :
542
+ if engine_weight_name in weight_refit_map :
544
543
weight_name_map [engine_weight_name ] = sd_weight_name
545
544
546
545
# Stage 2: Value mapping
@@ -549,10 +548,10 @@ def _save_weight_mapping(self) -> None:
549
548
# There is no direct connection in batch_norm layer. So skip it
550
549
pass
551
550
elif sd_weight_name not in sd or not TRTInterpreter .check_weight_equal (
552
- sd [sd_weight_name ], np_map [engine_weight_name ], torch_device
551
+ sd [sd_weight_name ], weight_refit_map [engine_weight_name ], torch_device
553
552
):
554
553
weight_name_map [engine_weight_name ] = TRTInterpreter .find_weight (
555
- engine_weight_name , np_map , sd , torch_device
554
+ engine_weight_name , weight_refit_map , sd , torch_device
556
555
)
557
556
if (
558
557
weight_name_map [engine_weight_name ] != ""
@@ -563,12 +562,13 @@ def _save_weight_mapping(self) -> None:
563
562
564
563
weight_name_map [engine_weight_name ] = [
565
564
weight_name_map [engine_weight_name ],
566
- np_map [engine_weight_name ].dtype ,
565
+ weight_refit_map [engine_weight_name ].dtype ,
567
566
]
568
567
569
568
weight_name_map ["constant_mapping" ] = constant_mapping
570
569
self .weight_name_map = weight_name_map
571
- del np_map , sd
570
+
571
+ del weight_refit_map , sd
572
572
gc .collect ()
573
573
torch .cuda .empty_cache ()
574
574
0 commit comments