From 53c1bcbae90032b7f40248a1e9fc7911957f565b Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 17 Jun 2025 21:56:44 +0000
Subject: [PATCH 1/3] Delete one copy

---
 py/torch_tensorrt/dynamo/_compiler.py                        | 3 ++-
 py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py       | 3 +--
 py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py | 4 +++-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index d7092f1e0f..3e75ec3a00 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -707,7 +707,8 @@ def compile(
 
     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
-        deallocate_module(exported_program.module(), delete_module=False)
+        deallocate_module(gm, delete_module=False)
+        # deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index b134b3d5f5..2897f829fc 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -498,8 +498,7 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 928b7284fe..a3e5734715 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -37,7 +37,9 @@ def constant_fold(
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
         replace_node_with_constant(
-            gm, node, torch.nn.Parameter(constant, requires_grad=False)
+            gm,
+            node,
+            torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
         )
 
     erased_params = []

From 052305477430d469201e6862eec13d3b46df84ec Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Wed, 18 Jun 2025 22:11:19 +0000
Subject: [PATCH 2/3] Added an example that can compile on A40 with this PR
 but cannot under main

---
 examples/apps/flux_demo.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py
index e222b4f772..061b45f327 100644
--- a/examples/apps/flux_demo.py
+++ b/examples/apps/flux_demo.py
@@ -44,6 +44,10 @@ def compile_model(
         torch_dtype=torch.float16,
     ).to(torch.float16)
 
+    pipe.transformer = FluxTransformer2DModel(
+        num_layers=23, num_single_layers=10, guidance_embeds=True
+    ).to(torch.float16)
+
     if args.low_vram_mode:
         pipe.enable_model_cpu_offload()
     else:

From f147421d2f00b9a9865900d35d18e07cc13903f6 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 24 Jun 2025 22:22:46 +0000
Subject: [PATCH 3/3] Commented out for NVBug people to debug

---
 examples/apps/flux_demo.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py
index 061b45f327..df4c35912a 100644
--- a/examples/apps/flux_demo.py
+++ b/examples/apps/flux_demo.py
@@ -44,9 +44,9 @@ def compile_model(
         torch_dtype=torch.float16,
     ).to(torch.float16)
 
-    pipe.transformer = FluxTransformer2DModel(
-        num_layers=23, num_single_layers=10, guidance_embeds=True
-    ).to(torch.float16)
+    # pipe.transformer = FluxTransformer2DModel(
+    #     num_layers=28, num_single_layers=12, guidance_embeds=True
+    # ).to(torch.float16)
 
     if args.low_vram_mode:
         pipe.enable_model_cpu_offload()
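
Reviewer notes (not part of the patch series above; all examples below are
illustrative sketches, not the library's actual implementation).

Note on the _compiler.py change in PATCH 1/3: exported_program.module()
constructs a fresh GraphModule, i.e. a second copy of the weights, so offloading
that throwaway copy could leave the original allocation alive; offloading gm,
the graph module that was actually lowered and compiled, releases the memory
that matters. A sketch of the intent only; the real deallocate_module lives in
torch_tensorrt and may behave differently:

    import torch

    def offload_to_cpu(module: torch.nn.Module) -> None:
        # Hypothetical stand-in for deallocate_module(gm, delete_module=False):
        # move the module that actually owns the weights to CPU in place, then
        # return cached blocks to the driver so TensorRT can use the freed VRAM.
        module.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()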
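Note on the _TRTInterpreter.py change in PATCH 1/3: the old code moved the whole
module onto the target device in place before taking its state_dict, dragging a
CPU-offloaded module back onto the GPU for the lifetime of the interpreter. The
new dict comprehension copies tensors one at a time into a temporary dict and
leaves the module's own storage untouched. A minimal sketch of the difference,
using a stand-in nn.Linear rather than the real interpreter state:

    import torch

    mod = torch.nn.Linear(4, 4)  # stand-in for self.module
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Old behavior: mutates the module, so all of its parameters end up (and
    # stay) on the target device.
    # mod.to(device)
    # sd = mod.state_dict()

    # New behavior: per-tensor copies into a temporary dict; the module's own
    # parameters remain wherever they were (e.g. offloaded to CPU).
    sd = {k: v.to(device) for k, v in mod.state_dict().items()}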
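Note on the constant_folding.py change in PATCH 1/3: a folded constant can come
out of constant propagation GPU-resident and as a non-contiguous view; calling
.cpu().contiguous() before wrapping it in nn.Parameter keeps the folded graph
free of lingering GPU copies, consistent with the pass's own comment that
constants are moved to CPU in the get_attr call. A small self-contained
illustration of what the two calls do (the tensor here is hypothetical, not one
produced by the pass):

    import torch

    constant = torch.randn(8, 4).t()  # a non-contiguous view, as a stand-in
    # .cpu() drops any GPU-resident storage; .contiguous() materializes the
    # view into dense memory so downstream consumers see a plain tensor.
    param = torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False)
    assert param.device.type == "cpu" and param.is_contiguous()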