Commit 51d970d

[docs] add the Stable diffusion with Jax/Flax Guide into the docs (#2487)
* add stable diffusion jax guide --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent a937e1b commit 51d970d

File tree

2 files changed: 252 additions and 0 deletions


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,8 @@
     title: How to contribute a Pipeline
   - local: using-diffusers/using_safetensors
     title: Using safetensors
+  - local: using-diffusers/stable_diffusion_jax_how_to
+    title: Stable Diffusion in JAX/Flax
   - local: using-diffusers/weighted_prompts
     title: Weighting Prompts
   title: Pipelines for Inference

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@

# 🧨 Stable Diffusion in JAX / Flax!

[[open-in-colab]]

🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers) supports Flax since version `0.5.1`! This allows for super-fast inference on Google TPUs, such as those available in Colab, Kaggle, or Google Cloud Platform.

This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to [this notebook](https://huggingface.co/docs/diffusers/stable_diffusion).

First, make sure you are using a TPU backend. If you are running this notebook in Colab, select `Runtime` in the menu above, then select the option "Change runtime type", and then select `TPU` under the `Hardware accelerator` setting.

Note that JAX is not exclusive to TPUs, but it shines on that hardware because each TPU server has 8 TPU accelerators working in parallel.

## Setup

First, make sure `diffusers` is installed.

```bash
!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
!pip install diffusers
```

```python
import jax.tools.colab_tpu

jax.tools.colab_tpu.setup_tpu()
import jax
```

```python
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
    "TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
```

```python out
Found 8 JAX devices of type Cloud TPU.
```

Then we import all the dependencies.

```python
import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
```

## Model Loading

TPU devices support `bfloat16`, an efficient half-float type. We'll use it for our tests, but you can also use `float32` to use full precision instead.

```python
dtype = jnp.bfloat16
```

Flax is a functional framework, so models are stateless and parameters are stored outside them. Loading the pre-trained Flax pipeline will return both the pipeline itself and the model weights (or parameters). We are using a `bf16` version of the weights, which leads to type warnings that you can safely ignore.

```python
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)
```
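
Since the parameters are returned as a plain nested dictionary of arrays, you can inspect them directly. The snippet below is only an optional sketch (the exact top-level key names depend on the checkpoint):

```python
# Optional: peek at the parameter tree we just loaded.
# `params` is a nested dict of arrays; the exact keys depend on the checkpoint.
from jax import tree_util

param_shapes = tree_util.tree_map(lambda leaf: leaf.shape, params)
print(list(param_shapes.keys()))  # top-level sub-modules of the pipeline
```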

## Inference

Since TPUs usually have 8 devices working in parallel, we'll replicate our prompt as many times as there are devices. Then we'll perform inference on the 8 devices at once, each responsible for generating one image. Thus, we'll get 8 images in the same amount of time it takes one chip to generate a single one.

After replicating the prompt, we obtain the tokenized text IDs by invoking the `prepare_inputs` function of the pipeline. The length of the tokenized text is set to 77 tokens, as required by the configuration of the underlying CLIP text model.

```python
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
```

```python out
(8, 77)
```

### Replication and parallelization

Model parameters and inputs have to be replicated across the 8 parallel devices we have. The parameters dictionary is replicated using `flax.jax_utils.replicate`, which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using `shard`.

```python
p_params = replicate(params)
```

```python
prompt_ids = shard(prompt_ids)
prompt_ids.shape
```

```python out
(8, 1, 77)
```

That shape means that each one of the `8` devices will receive as input a `jnp` array with shape `(1, 77)`. `1` is therefore the batch size per device. On TPUs with sufficient memory, it could be larger than `1` if we wanted to generate multiple images (per chip) at once.
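
For example, a per-device batch of `2` just means preparing `2 × jax.device_count()` prompts before sharding. This is only a hedged sketch (the prompt and variable names are illustrative; it assumes enough per-chip memory, and a new batch shape triggers a fresh compilation later on):

```python
# Sketch: 2 images per device (16 total on a v2-8), assuming enough TPU memory.
batch_per_device = 2
many_prompts = ["A watercolor painting of a lighthouse at sunset"] * (
    jax.device_count() * batch_per_device
)
many_prompt_ids = pipeline.prepare_inputs(many_prompts)  # shape: (16, 77)
many_prompt_ids = shard(many_prompt_ids)                 # shape: (8, 2, 77)
```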

We are almost ready to generate images! We just need to create a random number generator to pass to the generation function. This is the standard procedure in Flax, which is very serious and opinionated about random numbers: all functions that deal with random numbers are expected to receive a generator. This ensures reproducibility, even when we are training across multiple distributed devices.

The helper function below uses a seed to initialize a random number generator. As long as we use the same seed, we'll get the exact same results. Feel free to use different seeds when exploring results later in the notebook.

```python
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
```

We obtain an RNG key and then "split" it 8 times so each device receives a different generator. Therefore, each device will create a different image, and the full process is reproducible.

```python
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
```

JAX code can be compiled to an efficient representation that runs very fast. However, we need to ensure that all inputs have the same shape in subsequent calls; otherwise, JAX will have to recompile the code, and we wouldn't be able to take advantage of the optimized speed.
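
The toy example below (not part of the pipeline, just an illustration) shows this behavior with a jitted function: the first call for each input shape compiles, and later calls with the same shape reuse the cached program.

```python
# Illustration only: jitted functions are compiled per input shape.
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    return (2.0 * x).sum()


scaled_sum(jnp.ones((8, 77)))   # compiles for shape (8, 77)
scaled_sum(jnp.zeros((8, 77)))  # same shape: reuses the compiled program
scaled_sum(jnp.ones((4, 77)))   # new shape: triggers a recompilation
```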

The Flax pipeline can compile the code for us if we pass `jit=True` as an argument. It will also ensure that the model runs in parallel on the 8 available devices.

The first time we run the following cell it will take a long time to compile, but subsequent calls (even with different inputs) will be much faster. For example, it took more than a minute to compile on a TPU v2-8 when I tested, but subsequent inference runs take about **`7s`**.

```
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
```

```python out
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s
```

The returned array has shape `(8, 1, 512, 512, 3)`. We reshape it to get rid of the second dimension, obtain 8 images of `512 × 512 × 3`, and then convert them to PIL.

```python
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
```

### Visualization

Let's create a helper function to display images in a grid.

```python
def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
```

```python
image_grid(images, 2, 4)
```

![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)

## Using different prompts

We don't have to replicate the _same_ prompt across all the devices. We can do whatever we want: generate 2 prompts 4 times each, or even generate 8 different prompts at once. Let's do that!
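
As a quick, hedged sketch of the "2 prompts, 4 copies each" variant (the variable names here are just illustrative), the input preparation would look like this before we move on to 8 distinct prompts:

```python
# Sketch: two prompts repeated four times each, one copy per device.
prompts_2x4 = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
] * 4  # 8 prompts total
prompt_ids_2x4 = shard(pipeline.prepare_inputs(prompts_2x4))  # shape: (8, 1, 77)
```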

First, we'll define the 8 different prompts, one per device:

```python
prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
```

```python
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

image_grid(images, 2, 4)
```

![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)

## How does parallelization work?

We said before that the `diffusers` Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We'll now briefly look inside that process to show how it works.

JAX parallelization can be done in multiple ways. The easiest one revolves around using the `jax.pmap` function to achieve single-program, multiple-data (SPMD) parallelization. It means we'll run several copies of the same code, each on different data inputs. More sophisticated approaches are possible; if you are interested, we invite you to go over the [JAX documentation](https://jax.readthedocs.io/en/latest/index.html) and the [`pjit` pages](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit) to explore this topic!
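
Before applying this to the pipeline, here is a minimal toy sketch of SPMD with `jax.pmap` (illustrative names only, not pipeline code): the same function runs once per device, each copy on its own slice of the data.

```python
# Toy SPMD example: one copy of `normalize` runs on each device.
import jax
import jax.numpy as jnp


def normalize(x):
    return x / x.sum()


# Leading axis = number of devices; each device gets one row of shape (4,).
per_device_data = jnp.arange(jax.device_count() * 4.0).reshape(jax.device_count(), 4)
out = jax.pmap(normalize)(per_device_data)  # out has shape (num_devices, 4)
```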

`jax.pmap` does two things for us:
- Compiles (or `jit`s) the code, as if we had invoked `jax.jit()`. This does not happen when we call `pmap`, but the first time the pmapped function is invoked.
- Ensures the compiled code runs in parallel on all the available devices.

To show how it works, we `pmap` the `_generate` method of the pipeline, which is the private method that generates the images. Please note that this method may be renamed or removed in future releases of `diffusers`.

```python
p_generate = pmap(pipeline._generate)
```

After we use `pmap`, the prepared function `p_generate` will conceptually do the following:
* Invoke a copy of the underlying function `pipeline._generate` on each device.
* Send each device a different portion of the input arguments. That's what sharding is used for. In our case, `prompt_ids` has shape `(8, 1, 77)`. This array will be split into `8` chunks, and each copy of `_generate` will receive an input with shape `(1, 77)`.

We can code `_generate` completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (`1` in this example) and the dimensions that make sense for our code, and don't have to change anything to make it work in parallel.

Just as with the pipeline call above, the first time we run the following cell it will take a while, but subsequent calls will be much faster.

```
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
```

```python out
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
```

```python
images.shape
```

```python out
(8, 1, 512, 512, 3)
```

We use `block_until_ready()` to correctly measure inference time, because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to use that in your code; blocking will occur automatically when you want to use the result of a computation that has not yet been materialized.
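
As a minimal illustration of that asynchronous behavior (a toy example, unrelated to the pipeline), notice how the dispatch call returns almost immediately while `block_until_ready()` waits for the actual computation:

```python
# Toy illustration of asynchronous dispatch in JAX.
import time
import jax.numpy as jnp

x = jnp.ones((4096, 4096))

start = time.perf_counter()
y = jnp.dot(x, x)                        # returns quickly; work happens asynchronously
dispatch_time = time.perf_counter() - start

y.block_until_ready()                    # block until the result is actually computed
total_time = time.perf_counter() - start

print(f"dispatch: {dispatch_time:.4f}s, with compute: {total_time:.4f}s")
```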
