Commit 5a6e386

[docs] Quantization + torch.compile + offloading (#11703)
* draft * feedback * update * feedback * fix * feedback * feedback * fix * feedback
1 parent 42077e6 commit 5a6e386

File tree

3 files changed (+227 −17 lines):

- docs/source/en/_toctree.yml
- docs/source/en/optimization/memory.md
- docs/source/en/optimization/speed-memory-optims.md

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -180,6 +180,8 @@
     title: Caching
   - local: optimization/memory
     title: Reduce memory usage
+  - local: optimization/speed-memory-optims
+    title: Compile and offloading quantized models
   - local: optimization/pruna
     title: Pruna
   - local: optimization/xformers

docs/source/en/optimization/memory.md

Lines changed: 26 additions & 17 deletions
@@ -17,7 +17,7 @@ Modern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipe
 This guide will show you how to reduce your memory usage.

 > [!TIP]
-> Keep in mind these techniques may need to be adjusted depending on the model! For example, a transformer-based diffusion model may not benefit equally from these inference speed optimizations as a UNet-based model.
+> Keep in mind these techniques may need to be adjusted depending on the model. For example, a transformer-based diffusion model may not benefit as much from these memory optimizations as a UNet-based model.

 ## Multiple GPUs

@@ -145,7 +145,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
 ```

 > [!WARNING]
-> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support slicing.
+> The [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] classes don't support slicing.

 ## VAE tiling

@@ -172,7 +172,13 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
 > [!WARNING]
 > [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.

-## CPU offloading
+## Offloading
+
+Offloading strategies move layers or models that aren't currently active to the CPU to avoid increasing GPU memory. These strategies can be combined with quantization and torch.compile to balance inference speed and memory usage.
+
+Refer to the [Compile and offloading quantized models](./speed-memory-optims) guide for more details.
+
+### CPU offloading

 CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
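
As a reference for the section renamed above, the selective offloading it describes is enabled with [`~DiffusionPipeline.enable_sequential_cpu_offload`]. A minimal sketch, reusing the FLUX.1-schnell checkpoint from the surrounding examples:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
# submodules are moved to the GPU only while they are needed, then returned to the CPU
pipeline.enable_sequential_cpu_offload()

pipeline(prompt="An astronaut riding a horse on Mars").images[0]
```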

@@ -203,7 +209,7 @@ pipeline(
 print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
 ```

-## Model offloading
+### Model offloading

 Model offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline models, usually the text encoder, UNet, and VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they're completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is memory savings won't be as large.

@@ -219,7 +225,7 @@ from diffusers import DiffusionPipeline
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
 )
-pipline.enable_model_cpu_offload()
+pipeline.enable_model_cpu_offload()

 pipeline(
     prompt="An astronaut riding a horse on Mars",
@@ -234,7 +240,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G

 [`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoder's hidden state.

-## Group offloading
+### Group offloading

 Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and it is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.
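
For reference, a minimal sketch of enabling group offloading on a single pipeline component, assuming [`~ModelMixin.enable_group_offload`] accepts the same `block_level` arguments as the `apply_group_offloading` call shown later in this diff:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)

# offload groups of transformer layers to the CPU, two blocks at a time
pipeline.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)

pipeline(prompt="An astronaut riding a horse on Mars").images[0]
```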

@@ -278,7 +284,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
 export_to_video(video, "output.mp4", fps=8)
 ```

-### CUDA stream
+#### CUDA stream

 The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.
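
A minimal sketch of the `use_stream` option described above, following the `enable_group_offload` call pattern used elsewhere in this guide:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)

# prefetch the next layer on a CUDA stream while the current layer is still executing
pipeline.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```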

@@ -295,22 +301,25 @@ pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_d

 The `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.

-<Tip>
+#### Offloading to disk
+
+Group offloading can consume significant system memory depending on the model size. On systems with limited memory, try group offloading onto the disk as a secondary memory.

-The offloading strategies can be combined with [quantization](../quantization/overview.md) to enable further memory savings. For image generation, combining [quantization and model offloading](#model-offloading) can often give the best trade-off between quality, speed, and memory. However, for video generation, as the models are more
-compute-bound, [group-offloading](#group-offloading) tends to be better. Group offloading provides considerable benefits when weight transfers can be overlapped with computation (must use streams). When applying group offloading with quantization on image generation models at typical resolutions (1024x1024, for example), it is usually not possible to *fully* overlap weight transfers if the compute kernel finishes faster, making it communication bound between CPU/GPU (due to device synchronizations).
+Set the `offload_to_disk_path` argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`] to offload the model to the disk.

-</Tip>
+```py
+pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", offload_to_disk_path="path/to/disk")

-### Offloading to disk
+apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2, offload_to_disk_path="path/to/disk")
+```

-Group offloading can consume significant system RAM depending on the model size. In limited RAM environments,
-it can be useful to offload to the second memory, instead. You can do this by setting the `offload_to_disk_path`
-argument in either of [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer [here](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and
-[here](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled.
+Refer to these [two](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) [tables](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) to compare the speed and memory trade-offs.

 ## Layerwise casting

+> [!TIP]
+> Combine layerwise casting with [group offloading](#group-offloading) for even more memory savings.
+
 Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.

 > [!WARNING]
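
As a pointer for the layerwise casting section touched above, it is typically enabled per model with `enable_layerwise_casting`; a minimal sketch, assuming that method and the fp8 storage / bf16 compute dtypes mentioned in the text:

```py
import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="transformer", torch_dtype=torch.bfloat16
)
# store weights in fp8 and upcast to bf16 only for computation
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```
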
@@ -500,7 +509,7 @@ with torch.inference_mode():
 ## Memory-efficient attention

 > [!TIP]
-> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention!
+> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!

 The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.
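
For context on the attention types mentioned above, one widely used option is xFormers memory-efficient attention; a minimal sketch, assuming xformers is installed and using an SDXL checkpoint purely as an illustrative example:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# replace the default attention processor with the xformers memory-efficient implementation
pipeline.enable_xformers_memory_efficient_attention()
```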

docs/source/en/optimization/speed-memory-optims.md

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Compile and offloading quantized models

Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile), and various [offloading methods](./memory#offloading).

For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.

For video generation, combining quantization and [group offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.

The table below provides a comparison of optimization strategy combinations and their impact on latency and memory usage for Flux.

| combination | latency (s) | memory usage (GB) |
|---|---|---|
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |

<small>These results are benchmarked on Flux with an RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d">benchmarking script</a> if you're interested in evaluating your own model.</small>

This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.

```bash
pip install -U bitsandbytes
```

## Quantization and torch.compile

Start by [quantizing](../quantization/overview) a model to reduce the memory required for storage and [compiling](./fp16#torchcompile) it to accelerate inference.

Set the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) option `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

torch._dynamo.config.capture_dynamic_output_shape_ops = True

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# compile
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
pipeline("""
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
).images[0]
```

## Quantization, torch.compile, and offloading

In addition to quantization and torch.compile, try offloading if you need to reduce memory usage further. Offloading keeps layers or model components on the CPU and only moves them to the GPU when they're needed for computation.

Configure the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `cache_size_limit` during offloading to avoid excessive recompilation, and set `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.

<hfoptions id="offloading">
<hfoption id="model CPU offloading">

[Model CPU offloading](./memory#model-offloading) moves an individual pipeline component, like the transformer model, to the GPU when it is needed for computation. Otherwise, it is offloaded to the CPU.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# model CPU offloading
pipeline.enable_model_cpu_offload()

# compile
pipeline.transformer.compile()
pipeline(
    "cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
).images[0]
```

</hfoption>
<hfoption id="group offloading">

[Group offloading](./memory#group-offloading) moves the internal layers of an individual pipeline component, like the transformer model, to the GPU for computation and offloads them when they aren't required. At the same time, it uses the [CUDA stream](./memory#cuda-stream) feature to prefetch the next layer for execution.

By overlapping computation and data transfer, it is faster than model CPU offloading while also saving memory.

```py
# pip install ftfy
import torch
from diffusers import AutoModel, DiffusionPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import UMT5EncoderModel

torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder"],
)

text_encoder = UMT5EncoderModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16
)
pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# group offloading
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipeline.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True
)
pipeline.vae.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True
)
apply_group_offloading(
    pipeline.text_encoder,
    onload_device=onload_device,
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True
)

# compile
pipeline.transformer.compile()

prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""

output = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```

</hfoption>
</hfoptions>
