@@ -323,6 +323,13 @@ def _construct_trt_network_def(self) -> None:
323
323
)
324
324
325
325
def _save_weight_mapping (self ) -> None :
326
+ """
327
+ Construct the weight name mapping from engine weight name to state_dict weight name.
328
+ Cache the weight name mapping for future refitting use cases.
329
+ Two-stage weight name tracing:
330
+ 1. Name transformation from engine weight name to state_dict weight name
331
+ 2. Value mapping: for each weight in the INetworkDefinition, search for an identical weight in the state_dict
332
+ """
326
333
327
334
def find_weight (
328
335
weight_name : str , np_map : dict [str , Any ], sd : dict [str , Any ]
@@ -386,7 +393,7 @@ def check_weight_equal(
386
393
)
387
394
}
388
395
"""
389
-
396
+ # Stage 1: Name mapping
390
397
sd = self .module .state_dict ()
391
398
weight_name_map : dict [str , Any ] = {}
392
399
np_map = {}
@@ -413,6 +420,7 @@ def check_weight_equal(
413
420
[i for i in sd_weight_name_list [:- 1 ] if i ]
414
421
)
415
422
suffix = sd_weight_name_list [- 1 ]
423
+ # Retrieve the corresponding weight name(s) in state_dict
416
424
if layer_type == "CONSTANT" :
417
425
if "embedding" in suffix :
418
426
sd_weight_name = f"{ sd_weight_name } .{ torch_attr [0 ]} "
@@ -430,7 +438,7 @@ def check_weight_equal(
430
438
weight_name_map [engine_weight_name ] = sd_weight_name
431
439
np_map [engine_weight_name ] = weight
432
440
433
- # Value mapping
441
+ # Stage 2: Value mapping
434
442
for engine_weight_name , sd_weight_name in weight_name_map .items ():
435
443
if "SCALE" in engine_weight_name :
436
444
# There is no direct connection in batch_norm layer. So skip it
@@ -448,7 +456,6 @@ def check_weight_equal(
448
456
]
449
457
450
458
self .weight_name_map = weight_name_map
451
- # check = {k:(weight_name_map[k], np_map[k]) for k, v in np_map.items()}
452
459
453
460
def run (
454
461
self ,
0 commit comments