[Examples] InstructPix2Pix instruct training script #2478

Merged (49 commits, merged on Mar 23, 2023)

Commits
38064ed
add: initial implementation of the pix2pix instruct training script.
sayakpaul Feb 24, 2023
2c82597
shorten cli arg.
sayakpaul Feb 24, 2023
675f399
fix: main process check.
sayakpaul Feb 24, 2023
534aa56
fix: dataset column names.
sayakpaul Feb 24, 2023
9b3bbce
simplify tokenization.
sayakpaul Feb 24, 2023
a64f1bc
proper placement of null conditions.
sayakpaul Feb 24, 2023
8799d77
apply styling.
sayakpaul Feb 24, 2023
e7aaa2e
remove debugging message for conditioning do.
sayakpaul Feb 24, 2023
e779443
complete license.
sayakpaul Feb 24, 2023
b1e9e58
add: requirements.tzt
sayakpaul Feb 24, 2023
b16bd74
wandb column name order.
sayakpaul Feb 24, 2023
fe8f5c8
fix: augmentation.
sayakpaul Mar 14, 2023
03ab82d
change: dataset_id.
sayakpaul Mar 14, 2023
0c9fd19
fix: convert_to_np() call.
sayakpaul Mar 14, 2023
e6d09f0
fix: reshaping.
sayakpaul Mar 14, 2023
66f4b31
fix: final ema copy.
sayakpaul Mar 15, 2023
daaff42
Apply suggestions from code review
sayakpaul Mar 15, 2023
090a032
address PR comments.
sayakpaul Mar 15, 2023
eb3b7ca
add: readme details.
sayakpaul Mar 15, 2023
81ae464
config fix.
sayakpaul Mar 15, 2023
e0b17e5
downgrade version.
sayakpaul Mar 15, 2023
c3ebe7b
reduce image width in the readme.
sayakpaul Mar 15, 2023
1807add
note on hyperparameters during generation.
sayakpaul Mar 15, 2023
f12c041
add: output images.
sayakpaul Mar 15, 2023
30dd619
Merge branch 'main' into training/instruct-pix2pix
sayakpaul Mar 15, 2023
2d6d602
Merge branch 'main' into training/instruct-pix2pix
sayakpaul Mar 15, 2023
3627f96
update readme.
sayakpaul Mar 16, 2023
ac15648
minor edits to readme.
sayakpaul Mar 16, 2023
a1c14ff
debugging statement.
sayakpaul Mar 16, 2023
d2effc4
explicitly placement of the pipeline.
sayakpaul Mar 16, 2023
172599a
bump minimum diffusers version.
sayakpaul Mar 16, 2023
e0440e6
fix: device attribute error.
sayakpaul Mar 16, 2023
49364b7
weight dtype.
sayakpaul Mar 16, 2023
8b728dc
debugging.
sayakpaul Mar 16, 2023
f3b2f53
add dtype inform.
sayakpaul Mar 16, 2023
8faaab3
add seoarate te and vae.
sayakpaul Mar 16, 2023
52c75f9
add: explicit casting/
sayakpaul Mar 16, 2023
e0c78de
remove casting.
sayakpaul Mar 16, 2023
0074d05
up.
sayakpaul Mar 16, 2023
ba2855e
up 2.
sayakpaul Mar 16, 2023
59cd2da
up 3.
sayakpaul Mar 16, 2023
a53cc41
autocast.
sayakpaul Mar 17, 2023
4c8920c
disable mixed-precision in the final inference.
sayakpaul Mar 17, 2023
fbc626a
debugging information.
sayakpaul Mar 17, 2023
4e1545b
autocasting.
sayakpaul Mar 20, 2023
b6b1d74
add: instructpix2pix training section to the docs.
sayakpaul Mar 21, 2023
901a863
Merge branch 'main' into training/instruct-pix2pix
sayakpaul Mar 21, 2023
4813635
Merge branch 'main' into training/instruct-pix2pix
sayakpaul Mar 23, 2023
b499450
Empty-Commit
sayakpaul Mar 23, 2023
16 changes: 16 additions & 0 deletions docs/source/en/_toctree.yml
@@ -98,6 +98,22 @@
  - local: optimization/habana
    title: Habana Gaudi
  title: Optimization/Special Hardware
- sections:
  - local: training/overview
    title: Overview
  - local: training/unconditional_training
    title: Unconditional image generation
  - local: training/text_inversion
    title: Textual Inversion
  - local: training/dreambooth
    title: DreamBooth
  - local: training/text2image
    title: Text-to-image
  - local: training/lora
    title: Low-Rank Adaptation of Large Language Models (LoRA)
  - local: training/instructpix2pix
    title: InstructPix2Pix Training
  title: Training
- sections:
  - local: conceptual/philosophy
    title: Philosophy
181 changes: 181 additions & 0 deletions docs/source/en/training/instructpix2pix.mdx
@@ -0,0 +1,181 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# InstructPix2Pix

[InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs:

<p align="center">
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png" alt="instructpix2pix-inputs" width=600/>
</p>

The output is an "edited" image that reflects the edit instruction applied to the input image:

<p align="center">
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/output-gs%407-igs%401-steps%4050.png" alt="instructpix2pix-output" width=600/>
</p>

The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.

***Disclaimer: Even though `train_instruct_pix2pix.py` faithfully implements the InstructPix2Pix
training procedure from the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples), which can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the example folder and run:
```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config

write_basic_config()
```

### Toy example

As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset
is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.
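
If you'd like a quick look at what a training example contains before launching a run, you can load the dataset with 🤗 Datasets and inspect it. The snippet below is just a convenience sketch; the exact column names come from the dataset itself rather than being hard-coded here:

```python
from datasets import load_dataset

# Download the toy dataset and inspect its structure.
dataset = load_dataset("fusing/instructpix2pix-1000-samples", split="train")
print(dataset)               # number of rows and column names
print(dataset.column_names)

# Look at a single training example (image columns decode to PIL images).
example = dataset[0]
for column, value in example.items():
    print(column, type(value))
```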

Configure environment variables such as the dataset identifier and the Stable Diffusion
checkpoint:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATASET_ID="fusing/instructpix2pix-1000-samples"
```

Now, we can launch training:

```bash
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
--resolution=256 --random_flip \
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=15000 \
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--seed=42
```
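
The `--conditioning_dropout_prob` argument controls the conditioning dropout introduced in the InstructPix2Pix paper: during training, the edit prompt and the original image are each randomly replaced with null conditioning so that classifier-free guidance can later be applied over both the text and the image inputs. The snippet below is only a rough sketch of that idea; the helper and tensor names are ours, not the script's exact code:

```python
import torch


def apply_conditioning_dropout(prompt_embeds, original_image_latents, null_prompt_embeds, p):
    """Randomly drop the text and/or image conditioning for a batch (illustration only)."""
    bsz = prompt_embeds.shape[0]
    random_p = torch.rand(bsz, device=prompt_embeds.device)

    # Replace the edit-prompt embeddings with the null-prompt embeddings for ~2*p of the samples.
    prompt_mask = (random_p < 2 * p).reshape(bsz, 1, 1)
    prompt_embeds = torch.where(prompt_mask, null_prompt_embeds, prompt_embeds)

    # Zero out the original-image latents for an overlapping slice of the batch, so some samples
    # keep both conditionings, some keep only one, and some keep neither.
    image_mask = 1 - ((random_p >= p).float() * (random_p < 3 * p).float())
    original_image_latents = image_mask.reshape(bsz, 1, 1, 1) * original_image_latents

    return prompt_embeds, original_image_latents
```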

Additionally, we support performing validation inference to monitor training progress
with Weights and Biases. You can enable this feature with `report_to="wandb"`:

```bash
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
--resolution=256 --random_flip \
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=15000 \
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
--validation_prompt="make the mountains snowy" \
--seed=42 \
--report_to=wandb
```

We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.
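
If you haven't authenticated with Weights and Biases before, you can also log in from Python as a one-time step (you'll be prompted for your API key):

```python
import wandb

# Authenticate before launching training with --report_to=wandb.
wandb.login()
```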

[Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.

***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***

## Inference

Once training is complete, we can perform inference:

```python
import requests
import torch
from PIL import Image, ImageOps

from diffusers import StableDiffusionInstructPix2PixPipeline

model_id = "your_model_id"  # <- replace this
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
generator = torch.Generator("cuda").manual_seed(0)

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"


def download_image(url):
    # Download the image, honor any EXIF orientation, and convert it to RGB.
    image = Image.open(requests.get(url, stream=True).raw)
    image = ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


image = download_image(url)
prompt = "wipe out the lake"
num_inference_steps = 20
image_guidance_scale = 1.5
guidance_scale = 10

edited_image = pipe(
    prompt,
    image=image,
    num_inference_steps=num_inference_steps,
    image_guidance_scale=image_guidance_scale,
    guidance_scale=guidance_scale,
    generator=generator,
).images[0]
edited_image.save("edited_image.png")
```

An example model repo obtained using this training script can be found here: [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix).

We encourage you to play with the following three parameters to control
the speed and quality of the edits during inference:

* `num_inference_steps`
* `image_guidance_scale`
* `guidance_scale`

In particular, `image_guidance_scale` and `guidance_scale` can have a profound impact
on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).
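
If you want a quick feel for how these two scales interact, a small sweep like the one below can help. It reuses the `pipe`, `image`, and `prompt` from the inference example above; the value grids are arbitrary starting points, not recommendations:

```python
import torch

# Compare a few guidance settings; each run starts from the same seed so the
# differences come from the guidance scales alone.
for guidance_scale in (5.0, 7.5, 10.0):
    for image_guidance_scale in (1.0, 1.5, 2.0):
        generator = torch.Generator("cuda").manual_seed(0)
        edited = pipe(
            prompt,
            image=image,
            num_inference_steps=20,
            guidance_scale=guidance_scale,
            image_guidance_scale=image_guidance_scale,
            generator=generator,
        ).images[0]
        edited.save(f"edited_gs{guidance_scale}_igs{image_guidance_scale}.png")
```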
166 changes: 166 additions & 0 deletions examples/instruct_pix2pix/README.md
@@ -0,0 +1,166 @@
# InstructPix2Pix training example

[InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs:

<p align="center">
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png" alt="instructpix2pix-inputs" width=600/>
</p>

The output is an "edited" image that reflects the edit instruction applied to the input image:

<p align="center">
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/output-gs%407-igs%401-steps%4050.png" alt="instructpix2pix-output" width=600/>
</p>

The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.

***Disclaimer: Even though `train_instruct_pix2pix.py` faithfully implements the InstructPix2Pix
training procedure from the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples), which can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the example folder and run:
```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

### Toy example

As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset
is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.

Configure environment variables such as the dataset identifier and the Stable Diffusion
checkpoint:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATASET_ID="fusing/instructpix2pix-1000-samples"
```

Now, we can launch training:

```bash
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
--resolution=256 --random_flip \
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=15000 \
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--seed=42
```
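
For a bit of intuition about what the script optimizes: each training example pairs an original image, an edit instruction, and an edited image. The edited image is noised as in regular Stable Diffusion fine-tuning, while the latents of the original image are concatenated to the UNet input along the channel dimension (which is why the UNet's input convolution is extended to accept the extra channels). The snippet below is only an illustrative sketch of that objective with assumed names, not the script's exact code:

```python
import torch
import torch.nn.functional as F


def instruct_pix2pix_loss(unet, vae, noise_scheduler, edited_pixels, original_pixels, prompt_embeds):
    """Illustrative training objective; names and shapes are assumptions, not the script's exact code."""
    # Encode the edited (target) image into latents and scale them, as in standard SD fine-tuning.
    latents = vae.encode(edited_pixels).latent_dist.sample() * vae.config.scaling_factor

    # Encode the original (conditioning) image; these latents are NOT noised.
    original_latents = vae.encode(original_pixels).latent_dist.mode()

    # Standard diffusion objective on the edited-image latents.
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Concatenate the conditioning latents along the channel dimension; this is why the
    # UNet's input convolution takes extra input channels.
    unet_input = torch.cat([noisy_latents, original_latents], dim=1)

    model_pred = unet(unet_input, timesteps, encoder_hidden_states=prompt_embeds).sample
    return F.mse_loss(model_pred.float(), noise.float())
```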

Additionally, we support performing validation inference to monitor training progress
with Weights and Biases. You can enable this feature with `report_to="wandb"`:

```bash
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
--resolution=256 --random_flip \
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=15000 \
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
--validation_prompt="make the mountains snowy" \
--seed=42 \
--report_to=wandb
```

We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.

[Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.

***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***

## Inference

Once training is complete, we can perform inference:

```python
import requests
import torch
from PIL import Image, ImageOps

from diffusers import StableDiffusionInstructPix2PixPipeline

model_id = "your_model_id"  # <- replace this
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
generator = torch.Generator("cuda").manual_seed(0)

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"


def download_image(url):
    # Download the image, honor any EXIF orientation, and convert it to RGB.
    image = Image.open(requests.get(url, stream=True).raw)
    image = ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


image = download_image(url)
prompt = "wipe out the lake"
num_inference_steps = 20
image_guidance_scale = 1.5
guidance_scale = 10

edited_image = pipe(
    prompt,
    image=image,
    num_inference_steps=num_inference_steps,
    image_guidance_scale=image_guidance_scale,
    guidance_scale=guidance_scale,
    generator=generator,
).images[0]
edited_image.save("edited_image.png")
```
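
If you saved the fine-tuned pipeline locally instead of (or in addition to) pushing it to the Hub, you can load it from the training output directory in the same way; the path below is just a placeholder for whatever `--output_dir` you trained into:

```python
import torch

from diffusers import StableDiffusionInstructPix2PixPipeline

# Load the fine-tuned pipeline from a local output directory
# ("path/to/output_dir" is a placeholder, not a default).
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "path/to/output_dir", torch_dtype=torch.float16
).to("cuda")
```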

An example model repo obtained using this training script can be found here: [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix).

We encourage you to play with the following three parameters to control
the speed and quality of the edits during inference:

* `num_inference_steps`
* `image_guidance_scale`
* `guidance_scale`

In particular, `image_guidance_scale` and `guidance_scale` can have a profound impact
on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).
6 changes: 6 additions & 0 deletions examples/instruct_pix2pix/requirements.txt
@@ -0,0 +1,6 @@
accelerate
torchvision
transformers>=4.25.1
datasets
ftfy
tensorboard