# 🧨 Stable Diffusion in JAX / Flax!

[[open-in-colab]]

🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers) supports Flax since version `0.5.1`! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle, or Google Cloud Platform.

This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it on a GPU, please refer to [this notebook](https://huggingface.co/docs/diffusers/stable_diffusion).

First, make sure you are using a TPU backend. If you are running this notebook in Colab, select `Runtime` in the menu above, then select the option "Change runtime type", and then select `TPU` under the `Hardware accelerator` setting.

Note that JAX is not exclusive to TPUs, but it shines on that hardware because each TPU server has 8 TPU accelerators working in parallel.

## Setup

First, make sure diffusers is installed.

```bash
!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
!pip install diffusers
```

```python
import jax.tools.colab_tpu

# Configure the Colab TPU runtime so that JAX can see all 8 TPU cores.
jax.tools.colab_tpu.setup_tpu()
import jax
```

```python
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
    "TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
```

```python out
Found 8 JAX devices of type Cloud TPU.
```

Then we import all the dependencies.

```python
import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
```

## Model Loading

TPU devices support `bfloat16`, an efficient 16-bit floating point type. We'll use it for our tests, but you can also use `float32` for full precision instead.

```python
dtype = jnp.bfloat16
```

Flax is a functional framework, so models are stateless and parameters are stored outside them. Loading the pre-trained Flax pipeline will return both the pipeline itself and the model weights (or parameters). We are using a `bf16` version of the weights, which leads to type warnings that you can safely ignore.

```python
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)
```
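
If you'd rather run in full precision, you could load the default weights instead of the `bf16` revision. This is a minimal sketch, not part of the original guide, and it assumes the repository also hosts Flax-compatible weights on its default branch:

```python
# Hypothetical full-precision alternative: uses more device memory and is slower on TPU.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    dtype=jnp.float32,
)
```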

## Inference

Since TPUs usually have 8 devices working in parallel, we'll replicate our prompt as many times as there are devices. Then we'll perform inference on the 8 devices at once, each responsible for generating one image. Thus, we'll get 8 images in the same amount of time it takes one chip to generate a single one.

After replicating the prompt, we obtain the tokenized text ids by invoking the `prepare_inputs` function of the pipeline. The length of the tokenized text is set to 77 tokens, as required by the configuration of the underlying CLIP text model.

```python
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
```

```python out
(8, 77)
```

### Replication and parallelization

Model parameters and inputs have to be distributed across the 8 parallel devices we have. The parameters dictionary is replicated using `flax.jax_utils.replicate`, which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Input arrays are sharded using `shard`, which splits the first (batch) dimension so each device receives a different slice.

```python
p_params = replicate(params)
```

```python
prompt_ids = shard(prompt_ids)
prompt_ids.shape
```

```python out
(8, 1, 77)
```

That shape means that each one of the `8` devices will receive as input a `jnp` array with shape `(1, 77)`. `1` is therefore the batch size per device. On TPUs with sufficient memory, it could be larger than `1` if we wanted to generate multiple images (per chip) at once.
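
For example, to generate two images per device (16 in total), you could pass 16 prompts so that `shard` produces a per-device batch of `2`. This is just a quick sketch, not part of the original guide, and it assumes your TPU has enough memory for the larger batch:

```python
# Hypothetical larger batch: 2 copies of the prompt per device.
batch_per_device = 2
big_prompt = [prompt[0]] * jax.device_count() * batch_per_device
big_prompt_ids = shard(pipeline.prepare_inputs(big_prompt))
print(big_prompt_ids.shape)  # (8, 2, 77)
```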

We are almost ready to generate images! We just need to create a random number generator to pass to the generation function. This is the standard procedure in Flax, which is very serious and opinionated about random numbers: all functions that deal with random numbers are expected to receive a PRNG key. This ensures reproducibility, even when we are training across multiple distributed devices.

The helper function below uses a seed to initialize a random number generator key. As long as we use the same seed, we'll get the exact same results. Feel free to use different seeds when exploring results later in the notebook.

```python
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
```

We obtain a key and then "split" it into 8 different keys, so each device receives a different one. Therefore, each device will create a different image, and the full process is reproducible.

```python
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
```

JAX code can be compiled to an efficient representation that runs very fast. However, we need to ensure that all inputs have the same shape in subsequent calls; otherwise, JAX will have to recompile the code, and we wouldn't be able to take advantage of the optimized speed.
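
As a small, self-contained illustration (not part of the pipeline code), here is how a change in input shape triggers recompilation of a jitted function, while repeated shapes reuse the cached compiled program:

```python
import time

@jax.jit
def scaled_sum(x):
    return (x * 2.0).sum()

for shape in [(512, 512), (512, 512), (1024, 512)]:
    x = jnp.ones(shape)
    start = time.perf_counter()
    scaled_sum(x).block_until_ready()
    # The first call for each new shape is slow (tracing + compilation);
    # repeating a shape hits the compilation cache and is fast.
    print(shape, f"{time.perf_counter() - start:.4f}s")
```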

The Flax pipeline can compile the code for us if we pass `jit = True` as an argument. It will also ensure that the model runs in parallel on the 8 available devices.

The first time we run the following cell it will take a long time to compile, but subsequent calls (even with different inputs) will be much faster. For example, it took more than a minute to compile on a TPU v2-8 when I tested, but then it takes about **`7s`** for future inference runs.

```
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
```

```python out
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s
```

The returned array has shape `(8, 1, 512, 512, 3)`. We reshape it to get rid of the second dimension and obtain 8 images of `512 × 512 × 3`, and then convert them to PIL.

```python
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
```

### Visualization

Let's create a helper function to display images in a grid.

```python
def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
```

```python
image_grid(images, 2, 4)
```
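
If you also want to keep the results, you could save each image (and the grid) to disk. This is a small optional sketch that uses the `Path` import from earlier; the output directory name is arbitrary:

```python
# Save the individual images and the assembled grid to an arbitrary output directory.
output_dir = Path("sd_jax_outputs")
output_dir.mkdir(exist_ok=True)
for i, image in enumerate(images):
    image.save(output_dir / f"image_{i}.png")
image_grid(images, 2, 4).save(output_dir / "grid.png")
```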

## Using different prompts

We don't have to replicate the _same_ prompt on all the devices. We can do whatever we want: generate 2 prompts 4 times each, or even generate 8 different prompts at once. Let's do that!

First, we'll prepare a list with 8 different prompts:

```python
prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
```

```python
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

image_grid(images, 2, 4)
```
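
As mentioned above, you could also repeat prompts instead of using 8 distinct ones. Here is a quick sketch (the prompt selection is arbitrary) that generates 2 prompts 4 times each:

```python
# 2 prompts repeated 4 times each -> 8 inputs, one per device (alternating).
two_prompts = [
    "Labrador in the style of Hokusai",
    "Clown astronaut in space, with Earth in the background",
] * 4
prompt_ids = shard(pipeline.prepare_inputs(two_prompts))
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_grid(pipeline.numpy_to_pil(images), 2, 4)
```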

## How does parallelization work?

We said before that the `diffusers` Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We'll now briefly look inside that process to show how it works.

JAX parallelization can be done in multiple ways. The easiest one revolves around using the `jax.pmap` function to achieve single-program, multiple-data (SPMD) parallelization. It means we'll run several copies of the same code, each on different data inputs. More sophisticated approaches are possible; if you are interested, we invite you to go over the [JAX documentation](https://jax.readthedocs.io/en/latest/index.html) and the [`pjit` pages](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit) to explore this topic!

`jax.pmap` does two things for us:
- Compiles (or `jit`s) the code, as if we had invoked `jax.jit()`. This does not happen when we call `pmap`, but the first time the pmapped function is invoked.
- Ensures the compiled code runs in parallel on all the available devices (a toy standalone example follows).
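
Here is that toy example, unrelated to the pipeline: the same function runs on every device, each operating on its own slice of the input.

```python
# Toy SPMD example: each of the devices normalizes a different row of the input.
def normalize(x):
    return x / jnp.linalg.norm(x)

xs = jnp.arange(jax.device_count() * 4, dtype=jnp.float32).reshape(jax.device_count(), 4)
ys = jax.pmap(normalize)(xs)  # compiled on first call, then runs on all devices in parallel
print(ys.shape)  # (8, 4): one normalized row per device
```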

To show how it works, we `pmap` the `_generate` method of the pipeline, which is the private method that runs the image generation. Please note that this method may be renamed or removed in future releases of `diffusers`.

```python
p_generate = pmap(pipeline._generate)
```

After we use `pmap`, the prepared function `p_generate` will conceptually do the following:
* Invoke a copy of the underlying function `pipeline._generate` on each device.
* Send each device a different portion of the input arguments. That's what sharding is used for. In our case, `prompt_ids` has shape `(8, 1, 77)`. This array will be split into `8`, and each copy of `_generate` will receive an input with shape `(1, 77)`.

We can write `_generate` completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (`1` in this example) and the dimensions that make sense for our code, and we don't have to change anything to make it work in parallel.

Just like when we used the pipeline call, the first time we run the following cell it will take a while, but then it will be much faster.

```
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
```

```python out
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
```

```python
images.shape
```

```python out
(8, 1, 512, 512, 3)
```

We use `block_until_ready()` to correctly measure inference time, because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to use that in your code; blocking will occur automatically when you want to use the result of a computation that has not yet been materialized.
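
To see the effect of asynchronous dispatch, here is a small standalone timing sketch (unrelated to the pipeline) that compares the dispatch time with the time it takes for the result to actually be ready:

```python
import time

x = jnp.ones((4096, 4096))

start = time.perf_counter()
y = jnp.dot(x, x)  # returns almost immediately: the work is dispatched asynchronously
print(f"dispatch: {time.perf_counter() - start:.4f}s")

start = time.perf_counter()
y.block_until_ready()  # waits until the computation has actually finished
print(f"result ready: {time.perf_counter() - start:.4f}s")
```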