Skip to content

Commit 53c1bcb

Browse files
committed
Delete one copy
1 parent 0273726 commit 53c1bcb

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,8 @@ def compile(
707707

708708
# Move the weights in the state_dict to CPU
709709
if offload_module_to_cpu:
710-
deallocate_module(exported_program.module(), delete_module=False)
710+
deallocate_module(gm, delete_module=False)
711+
# deallocate_module(exported_program.module(), delete_module=False)
711712
logger.info(
712713
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
713714
)

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,7 @@ def _save_weight_mapping(self) -> None:
498498
_LOGGER.info("Building weight name mapping...")
499499
# Stage 1: Name mapping
500500
torch_device = to_torch_device(self.compilation_settings.device)
501-
self.module.to(torch_device)
502-
sd = self.module.state_dict()
501+
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
503502
weight_name_map: dict[str, Any] = {}
504503
weight_refit_map = self.ctx.weight_refit_map
505504
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def constant_fold(
3737
# For TRT INetwork construction the constants are moved to CPU in get_attr call.
3838
for node, constant in cf.node_replacements.items():
3939
replace_node_with_constant(
40-
gm, node, torch.nn.Parameter(constant, requires_grad=False)
40+
gm,
41+
node,
42+
torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
4143
)
4244

4345
erased_params = []

0 commit comments

Comments
 (0)