
## IF

You can use the LoRA and full dreambooth scripts to train the text-to-image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and the stage II upscaler
[IF model](https://huggingface.co/DeepFloyd/IF-II-L-v1.0).

Note that IF has a predicted variance, and our finetuning scripts only train the model's predicted error, so for finetuned IF models we switch to a fixed
variance schedule. The full finetuning scripts will update the scheduler config for the full saved model. However, when loading saved LoRA weights, you
must also update the pipeline's scheduler config yourself.

```py
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")

pipe.load_lora_weights("<lora weights path>")

# Update scheduler config to fixed variance schedule
pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small")
```

The snippet constructs a fresh scheduler with `from_config` rather than mutating the existing instance: depending on the argument, setting a scheduler config value can also change other instance variables, and schedulers are cheap to construct, so building a new one is the safer pattern.
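With the scheduler updated, the LoRA pipeline can be used for inference as usual. A minimal sketch continuing from the snippet above (the prompt and output path are placeholders):

```py
# continuing from above: move the pipeline to GPU and sample a 64x64 stage I image
pipe.to("cuda")
image = pipe("a sks dog", num_inference_steps=100).images[0]
image.save("sks_dog_64.png")
```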

Additionally, a few alternative CLI flags are needed for IF.

`--resolution=64`: IF is a pixel-space diffusion model. To operate on uncompressed pixels, the input images must be of a much smaller resolution.

`--pre_compute_text_embeddings`: IF uses [T5](https://huggingface.co/docs/transformers/model_doc/t5) for its text encoder. To save GPU memory, we pre-compute all text embeddings and then deallocate
T5.

`--tokenizer_max_length=77`: T5 has a longer default sequence length, but the default IF encoding procedure uses a smaller one.

`--text_encoder_use_attention_mask`: the attention mask must be passed to the T5 text encoder.
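To make these flags concrete, here is a rough sketch of what the pre-computation step amounts to (the prompt is a placeholder, and the training script's actual helper code differs in detail):

```py
import gc

import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/IF-I-XL-v1.0", subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained("DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder")

# --tokenizer_max_length=77: override T5's longer default sequence length
inputs = tokenizer(
    "a sks dog",
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="pt",
)

# --text_encoder_use_attention_mask: the mask is passed to T5
with torch.no_grad():
    prompt_embeds = text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask)[0]

# --pre_compute_text_embeddings: with the embeddings cached, T5 can be deallocated
del text_encoder
gc.collect()
torch.cuda.empty_cache()
```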

### Tips and Tricks
We find LoRA to be sufficient for finetuning the stage I model, as the low resolution of the model makes representing finegrained detail difficult regardless.

For common and/or not visually complex object concepts, you can get away with not finetuning the upscaler. Just be sure to adjust the prompt passed to the
upscaler to remove the new token from the instance prompt, i.e. if your stage I prompt is "a sks dog", use "a dog" for your stage II prompt.

For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
LoRA finetuning stage II.

For finegrained detail like faces, we find that lower learning rates work best.

For stage II, we find that lower learning rates are also needed.

### Stage II additional validation images

The stage II validation requires images to upscale; we can download a downsized version of the training set:

```py
from huggingface_hub import snapshot_download

local_dir = "./dog_downsized"
snapshot_download(
"diffusers/dog-example-downsized",
local_dir=local_dir,
repo_type="dataset",
ignore_patterns=".gitattributes",
)
```

### IF stage I LoRA Dreambooth
This training configuration requires ~28 GB VRAM.

```sh
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_lora"

accelerate launch train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a sks dog" \
--resolution=64 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--scale_lr \
--max_train_steps=1200 \
--validation_prompt="a sks dog" \
--validation_epochs=25 \
--checkpointing_steps=100 \
--pre_compute_text_embeddings \
--tokenizer_max_length=77 \
--text_encoder_use_attention_mask
```

### IF stage II LoRA Dreambooth

`--validation_images`: These images are upscaled during validation steps.

`--class_labels_conditioning=timesteps`: Pass additional timestep conditioning to the UNet, which is needed for stage II.

`--learning_rate=1e-6`: Lower learning rate than stage I.

`--resolution=256`: The upscaler expects higher resolution inputs.

```sh
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_upscale"
export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"

python train_dreambooth_lora.py \
--report_to wandb \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a sks dog" \
--resolution=256 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-6 \
--max_train_steps=2000 \
--validation_prompt="a sks dog" \
--validation_epochs=100 \
--checkpointing_steps=500 \
--pre_compute_text_embeddings \
--tokenizer_max_length=77 \
--text_encoder_use_attention_mask \
--validation_images $VALIDATION_IMAGES \
--class_labels_conditioning=timesteps
```
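Once trained, the upscaler LoRA can be sanity-checked in isolation. A minimal sketch, assuming the output directory above and one of the downsized validation images; note the fixed-variance scheduler swap required when loading LoRA weights:

```py
import torch
from diffusers import DiffusionPipeline
from PIL import Image

stage_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", torch_dtype=torch.float16).to("cuda")
stage_2.load_lora_weights("dreambooth_dog_upscale")

# LoRA weights do not update the saved scheduler config, so switch it manually
stage_2.scheduler = stage_2.scheduler.__class__.from_config(
    stage_2.scheduler.config, variance_type="fixed_small"
)

low_res = Image.open("dog_downsized/image_1.png")
image = stage_2(image=low_res, prompt="a sks dog").images[0]
image.save("sks_dog_upscaled.png")
```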

### IF Stage I Full Dreambooth
`--skip_save_text_encoder`: When training the full model, this will skip saving the entire T5 with the finetuned model. You can still load the pipeline
with the T5 from the original model.

`--use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.

`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.

Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.

```sh
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_if"

accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=64 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-7 \
--max_train_steps=150 \
--validation_prompt "a photo of sks dog" \
--validation_steps 25 \
--text_encoder_use_attention_mask \
--tokenizer_max_length 77 \
--pre_compute_text_embeddings \
--use_8bit_adam \
--set_grads_to_none \
--skip_save_text_encoder \
--push_to_hub
```
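Because of `--skip_save_text_encoder`, the saved pipeline contains no T5 weights. A sketch of loading it for inference, pairing the finetuned model with the original text encoder (the output directory name follows the config above):

```py
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel

# load T5 from the original IF checkpoint and attach it to the finetuned pipeline
text_encoder = T5EncoderModel.from_pretrained("DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder")
pipe = DiffusionPipeline.from_pretrained("dreambooth_if", text_encoder=text_encoder)
```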

### IF Stage II Full Dreambooth

`--learning_rate=1e-8`: An even lower learning rate than stage I full finetuning (1e-7).

`--resolution=256`: The upscaler expects higher resolution inputs.

```sh
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_upscale"
export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"

accelerate launch train_dreambooth.py \
--report_to wandb \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a sks dog" \
--resolution=256 \
--train_batch_size=2 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-8 \
--max_train_steps=2000 \
--validation_prompt="a sks dog" \
--validation_steps=150 \
--checkpointing_steps=500 \
--pre_compute_text_embeddings \
--tokenizer_max_length=77 \
--text_encoder_use_attention_mask \
--validation_images $VALIDATION_IMAGES \
--class_labels_conditioning timesteps \
--push_to_hub
```
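Finally, a hedged end-to-end sketch chaining the two finetuned stages. The directory names follow the configs above, stage I is loaded with the original T5 as shown earlier, and `encode_prompt` reuses that attached text encoder:

```py
import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel

# stage I was saved without T5, so attach the original text encoder
text_encoder = T5EncoderModel.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder", torch_dtype=torch.float16
)
stage_1 = DiffusionPipeline.from_pretrained(
    "dreambooth_if", text_encoder=text_encoder, torch_dtype=torch.float16
).to("cuda")
stage_2 = DiffusionPipeline.from_pretrained(
    "dreambooth_dog_upscale", torch_dtype=torch.float16
).to("cuda")

prompt_embeds, negative_embeds = stage_1.encode_prompt("a sks dog")

# stage I produces a 64x64 tensor, which stage II upscales to 256x256
image = stage_1(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
).images
image = stage_2(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
).images[0]
image.save("sks_dog_end_to_end.png")
```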