Skip to content

Commit 227deee

Browse files
committed
Added comments and fixed some issue
1 parent 1cdd3c0 commit 227deee

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,13 @@ def _construct_trt_network_def(self) -> None:
323323
)
324324

325325
def _save_weight_mapping(self) -> None:
326+
"""
327+
Construct the weight name mapping from engine weight name to state_dict weight name.
328+
Cache the weight name for future refitting usecases.
329+
Two-stage weight name tracing:
330+
1. Name transformation from engine weight name to state_dict weight name
331+
2. Value mapping that, for each weight in INetworkDefinition search for identical weight in state_dict
332+
"""
326333

327334
def find_weight(
328335
weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
@@ -386,7 +393,7 @@ def check_weight_equal(
386393
)
387394
}
388395
"""
389-
396+
# Stage 1: Name mapping
390397
sd = self.module.state_dict()
391398
weight_name_map: dict[str, Any] = {}
392399
np_map = {}
@@ -413,6 +420,7 @@ def check_weight_equal(
413420
[i for i in sd_weight_name_list[:-1] if i]
414421
)
415422
suffix = sd_weight_name_list[-1]
423+
# Retrieve each weight name(s) in state_dict
416424
if layer_type == "CONSTANT":
417425
if "embedding" in suffix:
418426
sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
@@ -430,7 +438,7 @@ def check_weight_equal(
430438
weight_name_map[engine_weight_name] = sd_weight_name
431439
np_map[engine_weight_name] = weight
432440

433-
# Value mapping
441+
# Stage 2: Value mapping
434442
for engine_weight_name, sd_weight_name in weight_name_map.items():
435443
if "SCALE" in engine_weight_name:
436444
# There is no direct connection in batch_norm layer. So skip it
@@ -448,7 +456,6 @@ def check_weight_equal(
448456
]
449457

450458
self.weight_name_map = weight_name_map
451-
# check = {k:(weight_name_map[k], np_map[k]) for k, v in np_map.items()}
452459

453460
def run(
454461
self,

tests/py/dynamo/models/test_model_refit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_fast_refit_one_engine():
127127

128128

129129
@pytest.mark.unit
130-
def test_fast_refit_one_engin_no_map():
130+
def test_fast_refit_one_engine_no_map():
131131

132132
model = models.resnet18(pretrained=False).eval().to("cuda")
133133
model2 = models.resnet18(pretrained=True).eval().to("cuda")
@@ -174,7 +174,7 @@ def test_fast_refit_one_engin_no_map():
174174

175175

176176
@pytest.mark.unit
177-
def test_fast_refit_one_engin_wrong_map():
177+
def test_fast_refit_one_engine_wrong_map():
178178

179179
model = models.resnet18(pretrained=False).eval().to("cuda")
180180
model2 = models.resnet18(pretrained=True).eval().to("cuda")

0 commit comments

Comments
 (0)