diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py
index e222b4f772..061b45f327 100644
--- a/examples/apps/flux_demo.py
+++ b/examples/apps/flux_demo.py
@@ -44,6 +44,10 @@ def compile_model(
         torch_dtype=torch.float16,
     ).to(torch.float16)
 
+    pipe.transformer = FluxTransformer2DModel(
+        num_layers=23, num_single_layers=10, guidance_embeds=True
+    ).to(torch.float16)
+
     if args.low_vram_mode:
         pipe.enable_model_cpu_offload()
     else:
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index d7092f1e0f..3e75ec3a00 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -707,7 +707,8 @@ def compile(
 
     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
-        deallocate_module(exported_program.module(), delete_module=False)
+        deallocate_module(gm, delete_module=False)
+        # deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index b134b3d5f5..2897f829fc 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -498,8 +498,7 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 928b7284fe..a3e5734715 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -37,7 +37,9 @@ def constant_fold(
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
         replace_node_with_constant(
-            gm, node, torch.nn.Parameter(constant, requires_grad=False)
+            gm,
+            node,
+            torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
        )
 
     erased_params = []
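
A note on the _compiler.py hunk, with a minimal sketch of the assumption behind it (the `Tiny` module below is hypothetical, for illustration only): `ExportedProgram.module()` unlifts the exported program into a fresh `GraphModule` on every call, so offloading that throwaway copy does not necessarily release the weights held by the lowered `gm` that `compile()` keeps using, whereas `deallocate_module(gm, ...)` targets the module whose parameters actually occupy GPU memory.

    import torch
    from torch.export import export

    class Tiny(torch.nn.Module):  # hypothetical model, for illustration only
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    ep = export(Tiny(), (torch.randn(2),))

    # Each call to .module() constructs a new GraphModule, so the two
    # results are distinct objects; deallocating one does not deallocate
    # the module the caller continues to hold.
    assert ep.module() is not ep.module()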
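The _TRTInterpreter.py hunk replaces an in-place `module.to(device)` with per-tensor copies. A minimal sketch of the difference, using a toy stand-in module (hypothetical; the real code operates on the interpreter's graph module): the dict comprehension materializes device copies only inside `sd`, so a module that was offloaded to CPU stays on CPU, and the copies become collectible once the weight-name mapping is built.

    import torch

    class Tiny(torch.nn.Module):  # hypothetical stand-in for self.module
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

    module = Tiny()  # assume offload_module_to_cpu already moved this to CPU
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Old pattern: mutates every parameter in place, pinning the whole
    # module to the target device for the rest of compilation.
    #   module.to(torch_device)
    #   sd = module.state_dict()

    # New pattern: only the dict holds device copies; the module's own
    # parameter storage is left where it was.
    sd = {k: v.to(torch_device) for k, v in module.state_dict().items()}
    assert next(module.parameters()).device.type == "cpu"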
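The constant_folding.py hunk normalizes folded constants with `.cpu().contiguous()`. A small illustration of why contiguity matters here, under the assumption that the consumer reads the tensor's raw host memory in row-major order (as TensorRT weight construction does): a strided view such as a transpose shares storage with its base, so its data pointer does not describe the values in the order the shape implies.

    import torch

    # A transposed view is not contiguous: its logical layout and its
    # underlying storage order disagree.
    constant = torch.arange(6.0).reshape(2, 3).t()
    assert not constant.is_contiguous()

    # .cpu().contiguous() yields a dense host copy whose memory matches
    # the logical layout, safe to hand over as a raw weight buffer.
    param = torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False)
    assert param.is_contiguous() and param.device.type == "cpu"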