
Commit 85107c3

Delete one copy

Parent: 1c00f0f

3 files changed: +18 −15 lines

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -702,7 +702,8 @@ def compile(
 
     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
-        deallocate_module(exported_program.module(), delete_module=False)
+        deallocate_module(gm, delete_module=False)
+        # deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
```
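The point of this change, and the likely reading of the commit title "Delete one copy": `ExportedProgram.module()` unlifts the exported graph into a fresh `GraphModule` on every call, so offloading that freshly built copy leaves the lowered `gm` (the module the compiler actually holds) untouched on the GPU. A minimal sketch of the behavior, using a made-up toy model rather than code from this repository:

```python
import torch
from torch.export import export


class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


ep = export(Tiny(), (torch.randn(1, 4),))

# module() builds a new GraphModule each time it is called, so the two
# results below are distinct objects. Offloading one of them says nothing
# about the other, which is why the patch now passes the lowered `gm`
# to deallocate_module instead of ep.module().
m1 = ep.module()
m2 = ep.module()
assert m1 is not m2
```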

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 13 additions & 13 deletions

```diff
@@ -400,7 +400,7 @@ def _construct_trt_network_def(self) -> None:
     @staticmethod
     def find_weight(
         weight_name: str,
-        np_map: dict[str, Any],
+        weight_refit_map: dict[str, Any],
         state_dict: dict[str, Any],
         device: torch.device,
     ) -> str:
@@ -413,7 +413,7 @@ def find_weight(
             state_dict: state of the graph module
         """
         with unset_fake_temporarily():
-            network_weight = torch.from_numpy(np_map[weight_name]).to(device)
+            network_weight = weight_refit_map[weight_name].to(device)
             for sd_w_name, sd_weight in state_dict.items():
                 if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
                     del state_dict[sd_w_name]
@@ -427,8 +427,8 @@ def check_weight_equal(
         device: torch.device,
     ) -> Any:
         with unset_fake_temporarily():
-            if not isinstance(network_weight, torch.Tensor):
-                network_weight = torch.from_numpy(network_weight).to(device)
+            if network_weight.device != device:
+                network_weight = network_weight.to(device)
             try:
                 return sd_weight.shape == network_weight.shape and torch.all(
                     torch.abs(sd_weight - network_weight) < 0.01
```
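These hunks all follow from one change: `ctx.weight_refit_map` now yields `torch.Tensor`s rather than numpy arrays, so the `torch.from_numpy` round-trips disappear and only a device check remains. A standalone sketch of the patched comparison; the function name and the small harness below are illustrative, not part of the source:

```python
import torch


def weights_match(sd_weight: torch.Tensor,
                  network_weight: torch.Tensor,
                  device: torch.device) -> bool:
    # The refit map already holds torch.Tensors, so no numpy conversion
    # is needed; just make sure both operands live on the same device.
    if network_weight.device != device:
        network_weight = network_weight.to(device)
    if sd_weight.shape != network_weight.shape:
        return False
    # Same elementwise tolerance as check_weight_equal: abs diff < 0.01.
    return bool(torch.all(torch.abs(sd_weight - network_weight) < 0.01))


w = torch.randn(8, 8)
assert weights_match(w, w.clone(), torch.device("cpu"))
```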
```diff
@@ -494,11 +494,10 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
-        np_map = self.ctx.weight_refit_map
-        constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
+        weight_refit_map = self.ctx.weight_refit_map
+        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
         net = self.ctx.net
         for i in range(net.num_layers):
             layer = net[i]
@@ -540,7 +539,7 @@ def _save_weight_mapping(self) -> None:
             else:
                 sd_weight_name = f"{sd_weight_name}.{torch_attr}"
 
-            if engine_weight_name in np_map:
+            if engine_weight_name in weight_refit_map:
                 weight_name_map[engine_weight_name] = sd_weight_name
 
         # Stage 2: Value mapping
@@ -549,10 +548,10 @@ def _save_weight_mapping(self) -> None:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
             elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
-                sd[sd_weight_name], np_map[engine_weight_name], torch_device
+                sd[sd_weight_name], weight_refit_map[engine_weight_name], torch_device
             ):
                 weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
-                    engine_weight_name, np_map, sd, torch_device
+                    engine_weight_name, weight_refit_map, sd, torch_device
                 )
                 if (
                     weight_name_map[engine_weight_name] != ""
@@ -563,12 +562,13 @@ def _save_weight_mapping(self) -> None:
 
             weight_name_map[engine_weight_name] = [
                 weight_name_map[engine_weight_name],
-                np_map[engine_weight_name].dtype,
+                weight_refit_map[engine_weight_name].dtype,
             ]
 
         weight_name_map["constant_mapping"] = constant_mapping
         self.weight_name_map = weight_name_map
-        del np_map, sd
+
+        del weight_refit_map, sd
         gc.collect()
         torch.cuda.empty_cache()
```

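Stage 1 also stops calling `self.module.to(torch_device)`, which would migrate the whole module in place and undo the CPU offload performed in `_compiler.py`; instead each weight is copied into a temporary device-resident dict that is explicitly dropped afterwards. A minimal sketch of that pattern under the assumption of a toy module; none of these names come from the patch:

```python
import gc

import torch


def snapshot_state_dict(module: torch.nn.Module,
                        device: torch.device) -> dict[str, torch.Tensor]:
    # Copy each weight to the target device without module.to(device):
    # the module itself stays wherever the caller put it (e.g. offloaded
    # to the CPU), and only this temporary dict holds device copies.
    return {k: v.to(device) for k, v in module.state_dict().items()}


model = torch.nn.Linear(4, 4)  # stays on the CPU throughout
sd = snapshot_state_dict(model, torch.device("cpu"))
assert model.weight.device.type == "cpu"

# Mirror the cleanup at the end of _save_weight_mapping: drop the
# copies, collect, and release any cached CUDA memory (a no-op here).
del sd
gc.collect()
torch.cuda.empty_cache()
```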
py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -37,7 +37,9 @@ def constant_fold(
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
         replace_node_with_constant(
-            gm, node, torch.nn.Parameter(constant, requires_grad=False)
+            gm,
+            node,
+            torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
         )
 
     erased_params = []
```
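The added `.cpu().contiguous()` pins each folded constant to host memory as a dense copy before it is wrapped in a `Parameter`, instead of leaving the original (possibly non-contiguous, possibly CUDA-resident) tensor in the graph. A small illustration of what those two calls guarantee; the tensor here is made up for the example:

```python
import torch

# A non-contiguous view, e.g. what constant folding might produce
# from a transpose.
folded = torch.randn(4, 2).t()
assert not folded.is_contiguous()

# What the patched call stores: a contiguous, CPU-resident copy.
param = torch.nn.Parameter(folded.cpu().contiguous(), requires_grad=False)
assert param.is_contiguous() and param.device.type == "cpu"
```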
