|
14 | 14 | 1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18 |
15 | 15 | 2. Save a Mutable Torch TensorRT Module |
16 | 16 | 3. Integration with Huggingface pipeline in LoRA use case |
| 17 | +4. Usage of dynamic shape with Mutable Torch TensorRT Module |
17 | 18 | """ |
18 | 19 |
|
19 | 20 | import numpy as np |
|
25 | 26 | torch.manual_seed(5) |
26 | 27 | inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] |
27 | 28 |
|
28 | | -# %% |
29 | | -# Initialize the Mutable Torch TensorRT Module with settings. |
30 | | -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
31 | | -settings = { |
32 | | - "use_python": False, |
33 | | - "enabled_precisions": {torch.float32}, |
34 | | - "immutable_weights": False, |
35 | | -} |
| 29 | +# # %% |
| 30 | +# # Initialize the Mutable Torch TensorRT Module with settings. |
| 31 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 32 | +# settings = { |
| 33 | +# "use_python": False, |
| 34 | +# "enabled_precisions": {torch.float32}, |
| 35 | +# "immutable_weights": False, |
| 36 | +# } |
36 | 37 |
|
37 | | -model = models.resnet18(pretrained=True).eval().to("cuda") |
38 | | -mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) |
39 | | -# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module. |
40 | | -mutable_module(*inputs) |
| 38 | +# model = models.resnet18(pretrained=True).eval().to("cuda") |
| 39 | +# mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) |
| 40 | +# # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module. |
| 41 | +# mutable_module(*inputs) |
41 | 42 |
|
42 | | -# %% |
43 | | -# Make modifications to the mutable module. |
44 | | -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 43 | +# # %% |
| 44 | +# # Make modifications to the mutable module. |
| 45 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
45 | 46 |
|
46 | | -# %% |
47 | | -# Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation. |
48 | | -model2 = models.resnet18(pretrained=False).eval().to("cuda") |
49 | | -mutable_module.load_state_dict(model2.state_dict()) |
| 47 | +# # %% |
| 48 | +# # Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation. |
| 49 | +# model2 = models.resnet18(pretrained=False).eval().to("cuda") |
| 50 | +# mutable_module.load_state_dict(model2.state_dict()) |
50 | 51 |
|
51 | 52 |
|
52 | | -# Check the output |
53 | | -# The refit happens while you call the mutable module again. |
54 | | -expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) |
55 | | -for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): |
56 | | - assert torch.allclose( |
57 | | - expected_output, refitted_output, 1e-2, 1e-2 |
58 | | - ), "Refit Result is not correct. Refit failed" |
| 53 | +# # Check the output |
| 54 | +# # The refit happens while you call the mutable module again. |
| 55 | +# expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) |
| 56 | +# for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): |
| 57 | +# assert torch.allclose( |
| 58 | +# expected_output, refitted_output, 1e-2, 1e-2 |
| 59 | +# ), "Refit Result is not correct. Refit failed" |
59 | 60 |
|
60 | | -print("Refit successfully!") |
| 61 | +# print("Refit successfully!") |
61 | 62 |
|
62 | | -# %% |
63 | | -# Saving Mutable Torch TensorRT Module |
64 | | -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 63 | +# # %% |
| 64 | +# # Saving Mutable Torch TensorRT Module |
| 65 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
65 | 66 |
|
66 | | -# Currently, saving is only enabled for C++ runtime, not python runtime. |
67 | | -torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl") |
68 | | -reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") |
| 67 | +# # Currently, saving is only when "use_python" = False in settings |
| 68 | +# torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl") |
| 69 | +# reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") |
69 | 70 |
|
70 | | -# %% |
71 | | -# Stable Diffusion with Huggingface |
72 | | -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 71 | +# # %% |
| 72 | +# # Stable Diffusion with Huggingface |
| 73 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
73 | 74 |
|
74 | | -# The LoRA checkpoint is from https://civitai.com/models/12597/moxin |
| 75 | +# # The LoRA checkpoint is from https://civitai.com/models/12597/moxin |
75 | 76 |
|
76 | | -from diffusers import DiffusionPipeline |
| 77 | +# from diffusers import DiffusionPipeline |
77 | 78 |
|
78 | | -with torch.no_grad(): |
79 | | - settings = { |
80 | | - "use_python_runtime": True, |
81 | | - "enabled_precisions": {torch.float16}, |
82 | | - "debug": True, |
83 | | - "immutable_weights": False, |
84 | | - } |
| 79 | +# with torch.no_grad(): |
| 80 | +# settings = { |
| 81 | +# "use_python_runtime": True, |
| 82 | +# "enabled_precisions": {torch.float16}, |
| 83 | +# "debug": True, |
| 84 | +# "immutable_weights": False, |
| 85 | +# } |
85 | 86 |
|
86 | | - model_id = "stabilityai/stable-diffusion-xl-base-1.0" |
87 | | - device = "cuda:0" |
| 87 | +# model_id = "stabilityai/stable-diffusion-xl-base-1.0" |
| 88 | +# device = "cuda:0" |
88 | 89 |
|
89 | | - prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed" |
90 | | - negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude" |
| 90 | +# prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed" |
| 91 | +# negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude" |
91 | 92 |
|
92 | | - pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) |
93 | | - pipe.to(device) |
| 93 | +# pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) |
| 94 | +# pipe.to(device) |
94 | 95 |
|
95 | | - # The only extra line you need |
96 | | - pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings) |
| 96 | +# # The only extra line you need |
| 97 | +# pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings) |
97 | 98 |
|
98 | | - image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
99 | | - image.save("./without_LoRA_mutable.jpg") |
| 99 | +# image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
| 100 | +# image.save("./without_LoRA_mutable.jpg") |
100 | 101 |
|
101 | | - # Standard Huggingface LoRA loading procedure |
102 | | - pipe.load_lora_weights( |
103 | | - "stablediffusionapi/load_lora_embeddings", |
104 | | - weight_name="all-disney-princess-xl-lo.safetensors", |
105 | | - adapter_name="lora1", |
106 | | - ) |
107 | | - pipe.set_adapters(["lora1"], adapter_weights=[1]) |
108 | | - pipe.fuse_lora() |
109 | | - pipe.unload_lora_weights() |
| 102 | +# # Standard Huggingface LoRA loading procedure |
| 103 | +# pipe.load_lora_weights( |
| 104 | +# "stablediffusionapi/load_lora_embeddings", |
| 105 | +# weight_name="all-disney-princess-xl-lo.safetensors", |
| 106 | +# adapter_name="lora1", |
| 107 | +# ) |
| 108 | +# pipe.set_adapters(["lora1"], adapter_weights=[1]) |
| 109 | +# pipe.fuse_lora() |
| 110 | +# pipe.unload_lora_weights() |
110 | 111 |
|
111 | | - # Refit triggered |
112 | | - image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
113 | | - image.save("./with_LoRA_mutable.jpg") |
| 112 | +# # Refit triggered |
| 113 | +# image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
| 114 | +# image.save("./with_LoRA_mutable.jpg") |
| 115 | + |
| 116 | + |
| 117 | +# %% |
| 118 | +# Use Mutable Torch TensorRT module with dynamic shape |
| 119 | +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 120 | +class Model(torch.nn.Module): |
| 121 | + def __init__(self): |
| 122 | + super().__init__() |
| 123 | + |
| 124 | + def forward(self, a, b, c={}): |
| 125 | + x = torch.matmul(a, b) |
| 126 | + x = torch.matmul(c["a"], c["b"].T) |
| 127 | + x = 2 * c["b"][1] |
| 128 | + return x |
| 129 | + |
| 130 | + |
| 131 | +model = Model().eval().cuda() |
| 132 | +inputs = (torch.rand(10, 3), torch.rand(3, 30)) |
| 133 | +kwargs = { |
| 134 | + "c": {"a": torch.rand(10, 30), "b": torch.rand(10, 30)}, |
| 135 | +} |
| 136 | + |
| 137 | +dim = torch.export.Dim("dim", min=1, max=50) |
| 138 | +dim2 = torch.export.Dim("dim2", min=1, max=50) |
| 139 | +args_dynamic_shapes = ({1: dim}, {0: dim}) |
| 140 | +kwarg_dynamic_shapes = { |
| 141 | + "c": {"a": {}, "b": {0: dim2}}, |
| 142 | +} |
| 143 | +# Export the model first with custom dynamic shape constraints |
| 144 | +# exp_program = torch.export.export(model, tuple(inputs), kwargs=k |
| 145 | +trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True) |
| 146 | +trt_gm.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes) |
| 147 | +# Run inference |
| 148 | +trt_gm(*inputs, **kwargs) |
0 commit comments