
Commit 977162c

Adds a document on token merging (#3208)
* add document on token merging.
* fix headline.
* fix: headline.
* add some samples for comparison.
1 parent 744663f commit 977162c

2 files changed: +118 −0

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -105,6 +105,8 @@
      title: MPS
    - local: optimization/habana
      title: Habana Gaudi
+   - local: optimization/tome
+     title: Token Merging
    title: Optimization/Special Hardware
  - sections:
    - local: conceptual/philosophy
```

docs/source/en/optimization/tome.mdx

Lines changed: 116 additions & 0 deletions
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Token Merging

Token Merging (introduced in [Token Merging: Your ViT But Faster](https://arxiv.org/abs/2210.09461)) works by progressively merging redundant tokens/patches in the forward pass of a Transformer-based network. It can speed up the inference latency of the underlying network.
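
To build intuition for the merging step, here is a toy sketch of the bipartite soft matching the paper describes: tokens are split into two sets, each token in one set is matched to its most similar token in the other, and the `r` most similar pairs are merged by averaging. This is our simplified illustration, not the `tomesd` implementation (the paper additionally tracks token sizes for a weighted average):

```python
import torch
import torch.nn.functional as F

def merge_tokens(x: torch.Tensor, r: int) -> torch.Tensor:
    # x: (batch, tokens, dim). Merges the r most redundant tokens.
    bsz, n, d = x.shape
    a, b = x[:, ::2], x[:, 1::2]  # alternating split into two token sets
    # Cosine similarity between every token in `a` and every token in `b`.
    sim = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).transpose(-2, -1)
    score, match = sim.max(dim=-1)  # best partner in `b` for each token in `a`
    order = score.argsort(dim=-1, descending=True)
    src_idx, keep_idx = order[:, :r], order[:, r:]  # merge the r most similar
    src = a.gather(1, src_idx.unsqueeze(-1).expand(-1, -1, d))
    dst = match.gather(1, src_idx).unsqueeze(-1).expand(-1, -1, d)
    # Fold each merged token into its partner (unweighted average for simplicity).
    b = b.scatter_reduce(1, dst, src, reduce="mean", include_self=True)
    kept = a.gather(1, keep_idx.unsqueeze(-1).expand(-1, -1, d))
    return torch.cat([kept, b], dim=1)  # (batch, tokens - r, dim)

x = torch.randn(2, 77, 64)
print(merge_tokens(x, r=10).shape)  # torch.Size([2, 67, 64])
```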
After Token Merging (ToMe) was released, the authors released [Token Merging for Fast Stable Diffusion](https://arxiv.org/abs/2303.17604), which introduced a version of ToMe that is more compatible with Stable Diffusion. We can use ToMe to gracefully speed up the inference latency of a [`DiffusionPipeline`]. This doc discusses how to apply ToMe to the [`StableDiffusionPipeline`], the expected speedups, and the qualitative aspects of using ToMe on the [`StableDiffusionPipeline`].

## Using ToMe

The authors of ToMe released a convenient Python library called [`tomesd`](https://github.com/dbolya/tomesd) that lets us apply ToMe to a [`DiffusionPipeline`] like so:
```diff
  import torch
  from diffusers import StableDiffusionPipeline
  import tomesd

  pipeline = StableDiffusionPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
  ).to("cuda")
+ tomesd.apply_patch(pipeline, ratio=0.5)

  image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```
And that’s it!

`tomesd.apply_patch()` exposes [a number of arguments](https://github.com/dbolya/tomesd#usage) that let us strike a balance between pipeline inference speed and the quality of the generated images. Among those arguments, the most important one is `ratio`, which controls the number of tokens that are merged during the forward pass. For more details on `tomesd`, please refer to [the original repository](https://github.com/dbolya/tomesd) and [the paper](https://arxiv.org/abs/2303.17604).
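
For example, one can experiment with a more aggressive `ratio` and then restore the original pipeline. Treat the following as a sketch based on the `tomesd` README at the time of writing, and double-check argument names against the repository:

```python
# Sketch: tune `ratio` (the fraction of tokens to merge); higher is faster but
# can cost quality. Per the `tomesd` README, `apply_patch` also exposes other
# knobs (e.g. which layers to patch) documented in its usage section.
tomesd.apply_patch(pipeline, ratio=0.75)  # more aggressive merging
image_fast = pipeline("a photo of an astronaut riding a horse on mars").images[0]

tomesd.remove_patch(pipeline)  # undo the patch, restoring the vanilla pipeline
```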
## Benchmarking `tomesd` with `StableDiffusionPipeline`

We benchmarked the impact of using `tomesd` on [`StableDiffusionPipeline`] along with [xFormers](https://huggingface.co/docs/diffusers/optimization/xformers) across different image resolutions. We used A100 and V100 as our test GPU devices, with the following development environment:
```bash
- `diffusers` version: 0.15.1
- Python version: 3.8.16
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.13.2
- Transformers version: 4.27.2
- Accelerate version: 0.18.0
- xFormers version: 0.0.16
- tomesd version: 0.1.2
```
We used [this script](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335) for benchmarking; its core timing logic resembles the sketch below.
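
This is a simplified, hypothetical sketch (the `time_pipeline` helper is ours), not the exact gist:

```python
import time

import torch
import tomesd
from diffusers import StableDiffusionPipeline

def time_pipeline(pipeline, batch_size: int, num_runs: int = 3) -> float:
    prompt = "a photo of an astronaut riding a horse on mars"
    pipeline(prompt, num_images_per_prompt=batch_size)  # warm-up, not timed
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_runs):
        pipeline(prompt, num_images_per_prompt=batch_size)
    torch.cuda.synchronize()
    return (time.time() - start) / num_runs  # average seconds per call

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

vanilla = time_pipeline(pipeline, batch_size=2)
tomesd.apply_patch(pipeline, ratio=0.5)
tome = time_pipeline(pipeline, batch_size=2)
print(f"ToMe speedup: {100 * (vanilla - tome) / vanilla:.2f}%")
```

Following are our findings: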
### A100

| Resolution | Batch size | Vanilla | ToMe | ToMe + xFormers | ToMe speedup (%) | ToMe + xFormers speedup (%) |
| --- | --- | --- | --- | --- | --- | --- |
| 512 | 10 | 6.88 | 5.26 | 4.69 | 23.55 | 31.83 |
| 768 | 10 | OOM | 14.71 | 11 | - | - |
| | 8 | OOM | 11.56 | 8.84 | - | - |
| | 4 | OOM | 5.98 | 4.66 | - | - |
| | 2 | 4.99 | 3.24 | 3.1 | 35.07 | 37.88 |
| | 1 | 3.29 | 2.24 | 2.03 | 31.91 | 38.30 |
| 1024 | 10 | OOM | OOM | OOM | - | - |
| | 8 | OOM | OOM | OOM | - | - |
| | 4 | OOM | 12.51 | 9.09 | - | - |
| | 2 | OOM | 6.52 | 4.96 | - | - |
| | 1 | 6.4 | 3.61 | 2.81 | 43.59 | 56.09 |
***The timings reported here are in seconds. Speedups are calculated over the `Vanilla` timings, i.e., speedup (%) = 100 × (Vanilla − ToMe) / Vanilla; for example, at 512x512 with batch size 10, 100 × (6.88 − 5.26) / 6.88 ≈ 23.55%.***
### V100

| Resolution | Batch size | Vanilla | ToMe | ToMe + xFormers | ToMe speedup (%) | ToMe + xFormers speedup (%) |
| --- | --- | --- | --- | --- | --- | --- |
| 512 | 10 | OOM | 10.03 | 9.29 | - | - |
| | 8 | OOM | 8.05 | 7.47 | - | - |
| | 4 | 5.7 | 4.3 | 3.98 | 24.56 | 30.18 |
| | 2 | 3.14 | 2.43 | 2.27 | 22.61 | 27.71 |
| | 1 | 1.88 | 1.57 | 1.57 | 16.49 | 16.49 |
| 768 | 10 | OOM | OOM | 23.67 | - | - |
| | 8 | OOM | OOM | 18.81 | - | - |
| | 4 | OOM | 11.81 | 9.7 | - | - |
| | 2 | OOM | 6.27 | 5.2 | - | - |
| | 1 | 5.43 | 3.38 | 2.82 | 37.75 | 48.07 |
| 1024 | 10 | OOM | OOM | OOM | - | - |
| | 8 | OOM | OOM | OOM | - | - |
| | 4 | OOM | OOM | 19.35 | - | - |
| | 2 | OOM | 13 | 10.78 | - | - |
| | 1 | OOM | 6.66 | 5.54 | - | - |
As seen in the tables above, the speedup from `tomesd` becomes more pronounced at larger image resolutions. Interestingly, `tomesd` also makes it possible to run the pipeline at higher resolutions, like 1024x1024, that would otherwise go out of memory.

It might be possible to speed up inference even further with [`torch.compile()`](https://huggingface.co/docs/diffusers/optimization/torch2.0).
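
For instance, one could compile the UNet on top of the ToMe patch. This is a hedged sketch assuming PyTorch 2.0 is available, not something we benchmarked here:

```python
# Sketch (assumes PyTorch 2.0): compile the UNet, typically the main bottleneck.
# The first pipeline call triggers compilation and is slow; later calls benefit.
pipeline.unet = torch.compile(pipeline.unet)
image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```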
## Quality

As reported in [the paper](https://arxiv.org/abs/2303.17604), ToMe can preserve the quality of the generated images to a great extent while speeding up inference. Increasing the `ratio` speeds up inference further, but that might come at the cost of a deterioration in image quality.

To test the quality of the generated samples using our setup, we sampled a few prompts from the “Parti Prompts” (introduced in [Parti](https://parti.research.google/)) and performed inference with the [`StableDiffusionPipeline`] in the following settings (a sketch of the sampling loop follows the list):
- Vanilla [`StableDiffusionPipeline`]
- [`StableDiffusionPipeline`] + ToMe
- [`StableDiffusionPipeline`] + ToMe + xFormers
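
To keep the settings comparable, each sample should be generated from the same seed. Here is a minimal, hypothetical sketch of that loop (the `generate` helper is ours; the exact script we used is linked at the end of this section):

```python
# Sketch: sample each setting with a fixed seed so the outputs are directly
# comparable. Reuses the `pipeline` built earlier in this doc.
def generate(pipeline, prompt: str, seed: int = 0):
    generator = torch.Generator("cuda").manual_seed(seed)
    return pipeline(prompt, generator=generator).images[0]

prompt = "a photo of an astronaut riding a horse on mars"
vanilla_image = generate(pipeline, prompt)             # 1. vanilla

tomesd.apply_patch(pipeline, ratio=0.5)                # 2. + ToMe
tome_image = generate(pipeline, prompt)

pipeline.enable_xformers_memory_efficient_attention()  # 3. + ToMe + xFormers
tome_xformers_image = generate(pipeline, prompt)
```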
We didn’t notice any significant decrease in the quality of the generated samples. Here are some samples:

![tome-samples](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/tome/tome_samples.png)

You can check out the generated samples [here](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=). We used [this script](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd) to conduct this experiment.