dreambooth upscaling fix added latents #3659

Merged
20 changes: 14 additions & 6 deletions docs/source/en/training/dreambooth.mdx
@@ -540,10 +540,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I
For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
LoRA finetuning stage II.

- For finegrained detail like faces, we find that lower learning rates work best.
+ For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.

+ For stage II, we find that lower learning rates are also needed.

+ We found experimentally that the DDPM scheduler with the default larger number of denoising steps sometimes works better than the DPM Solver scheduler
+ used in the training scripts.

> Contributor: Nice!
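If you want to try this, the scheduler can be swapped at validation/inference time. A minimal sketch, assuming a locally saved DreamBooth output directory (the path below is a placeholder, not part of this PR):

```py
from diffusers import DDPMScheduler, DiffusionPipeline

# Load the fine-tuned stage II upscaler (placeholder path for a local DreamBooth output dir).
pipe = DiffusionPipeline.from_pretrained("path/to/dreambooth-if-stage-ii-output")

# The training scripts default to a DPM Solver scheduler for validation; this swaps in DDPM,
# reusing the model's scheduler config but with DDPM's larger default number of denoising steps.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
```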

### Stage II additional validation images

The stage II validation requires images to upscale; we can download a downsized version of the training set:
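As a rough sketch of what these validation inputs amount to, they can also be produced locally by downsizing the training images. The paths and the 64x64 size are illustrative assumptions (64x64 being the stage I output resolution):

```py
from pathlib import Path

from PIL import Image

src = Path("dog")             # original training images (illustrative path)
dst = Path("dog_downsized")   # low-resolution copies to feed stage II validation
dst.mkdir(exist_ok=True)

for image_path in src.glob("*.jpeg"):  # adjust the extension to your data
    # Stage II upscales low-resolution stage I outputs, so validation inputs are downsized accordingly.
    Image.open(image_path).resize((64, 64)).save(dst / image_path.name)
```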
@@ -631,7 +634,8 @@ with a T5 loaded from the original model.

`use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.

- `--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.
+ `--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is
+ likely the learning rate can be increased with larger batch sizes.

Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
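Not part of the diff, but for context: `use_8bit_adam` switches the optimizer to the 8-bit AdamW implementation from bitsandbytes, roughly along these lines (the model below is a stand-in for the IF UNet):

```py
import torch

import bitsandbytes as bnb

model = torch.nn.Linear(16, 16)  # stand-in for the IF UNet being fine-tuned

# 8-bit AdamW stores optimizer state in 8-bit precision, which is what keeps the
# full XL IF model's memory footprint near the ~48 GB figure quoted above.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-7, weight_decay=1e-2)
```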

@@ -656,18 +660,22 @@ accelerate launch train_dreambooth.py \
--text_encoder_use_attention_mask \
--tokenizer_max_length 77 \
--pre_compute_text_embeddings \
- --use_8bit_adam \ #
+ --use_8bit_adam \
--set_grads_to_none \
--skip_save_text_encoder \
--push_to_hub
```
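Similarly, `--pre_compute_text_embeddings` encodes the prompt with the T5 text encoder once up front so the encoder can be freed before training starts. A simplified sketch of the idea; the model id, prompt, and tokenization settings are illustrative and only mirror the flags above:

```py
import torch
from transformers import T5EncoderModel, T5Tokenizer

model_id = "DeepFloyd/IF-I-XL-v1.0"  # illustrative stage I checkpoint
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")

inputs = tokenizer(
    "a photo of sks dog",  # instance prompt; illustrative
    max_length=77,         # matches --tokenizer_max_length 77
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask)[0]

# The embeddings are reused at every training step, so the (large) T5 encoder can be freed.
del text_encoder
```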

### IF Stage II Full Dreambooth

- `--learning_rate=1e-8`: Even lower learning rate.
+ `--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that learning rates as low as
+ 1e-8 were required.

`--resolution=256`: The upscaler expects higher-resolution inputs.

+ `--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II, particularly with
+ faces, required large effective batch sizes.
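For concreteness, the effective batch size of the command below (assuming a single-process `accelerate` launch) works out to:

```py
train_batch_size = 2
gradient_accumulation_steps = 6
num_processes = 1  # set by how accelerate is launched; assumed single GPU here

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 12, versus the effective batch size of 4 discussed above
```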

```sh
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
@@ -682,8 +690,8 @@ accelerate launch train_dreambooth.py \
--instance_prompt="a sks dog" \
--resolution=256 \
--train_batch_size=2 \
- --gradient_accumulation_steps=2 \
- --learning_rate=1e-8 \
+ --gradient_accumulation_steps=6 \
+ --learning_rate=5e-6 \

> Contributor: Cool, makes sense that bigger batch size will help here
--max_train_steps=2000 \
--validation_prompt="a sks dog" \
--validation_steps=150 \
20 changes: 14 additions & 6 deletions examples/dreambooth/README.md
@@ -574,10 +574,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I
For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
LoRA finetuning stage II.

- For finegrained detail like faces, we find that lower learning rates work best.
+ For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.

+ For stage II, we find that lower learning rates are also needed.

+ We found experimentally that the DDPM scheduler with the default larger number of denoising steps sometimes works better than the DPM Solver scheduler
+ used in the training scripts.

### Stage II additional validation images

The stage II validation requires images to upscale; we can download a downsized version of the training set:
@@ -665,7 +668,8 @@ with a T5 loaded from the original model.

`use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.

- `--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.
+ `--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is
+ likely the learning rate can be increased with larger batch sizes.

Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.

@@ -690,18 +694,22 @@ accelerate launch train_dreambooth.py \
--text_encoder_use_attention_mask \
--tokenizer_max_length 77 \
--pre_compute_text_embeddings \
- --use_8bit_adam \ #
+ --use_8bit_adam \
--set_grads_to_none \
--skip_save_text_encoder \
--push_to_hub
```

### IF Stage II Full Dreambooth

- `--learning_rate=1e-8`: Even lower learning rate.
+ `--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that learning rates as low as
+ 1e-8 were required.

`--resolution=256`: The upscaler expects higher-resolution inputs.

+ `--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II, particularly with
+ faces, required large effective batch sizes.

```sh
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
@@ -716,8 +724,8 @@ accelerate launch train_dreambooth.py \
--instance_prompt="a sks dog" \
--resolution=256 \
--train_batch_size=2 \
- --gradient_accumulation_steps=2 \
- --learning_rate=1e-8 \
+ --gradient_accumulation_steps=6 \
+ --learning_rate=5e-6 \
--max_train_steps=2000 \
--validation_prompt="a sks dog" \
--validation_steps=150 \
11 changes: 2 additions & 9 deletions examples/dreambooth/train_dreambooth.py
@@ -52,7 +52,6 @@
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
- from diffusers.utils.torch_utils import randn_tensor


if is_wandb_available():
@@ -1212,14 +1211,8 @@ def compute_text_embeddings(prompt):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)

- if unet.config.in_channels > channels:
-     needed_additional_channels = unet.config.in_channels - channels
-     additional_latents = randn_tensor(
-         (bsz, needed_additional_channels, height, width),
-         device=noisy_model_input.device,
-         dtype=noisy_model_input.dtype,
-     )
-     noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+ if unet.config.in_channels == channels * 2:
+     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
> Contributor Author (on lines +1214 to +1215): Just doubling the existing latents is sufficient, as we use the same timestep for both of the conditioning inputs!
>
> Theoretically, we could do something along the lines of training on two separate noising amounts for each set of channels, or bias one towards a lower amount of noise, as the pipeline's default is for a quarter of total noise to be added to the upsampled input.
>
> However, this is the simplest implementation for now, so I think it's ok!


if args.class_labels_conditioning == "timesteps":
    class_labels = timesteps
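To make the effect of the new branch concrete, here is a standalone sketch with dummy tensors. The shapes are illustrative; the stage II UNet's `in_channels` is twice the image channel count because it also takes the noised low-resolution conditioning image:

```py
import torch

# Dummy shapes for illustration only.
bsz, channels, height, width = 4, 3, 256, 256
noisy_model_input = torch.randn(bsz, channels, height, width)

# For the IF stage II upscaler, unet.config.in_channels is twice the image channel count
# (noisy image plus the noised low-resolution conditioning image).
unet_in_channels = 2 * channels

if unet_in_channels == channels * 2:
    # The PR simply duplicates the noised input for both halves, since the same timestep
    # is used for both conditioning inputs (see the author's comment above).
    model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

print(model_input.shape)  # torch.Size([4, 6, 256, 256]), matching the UNet's expected input
```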
11 changes: 2 additions & 9 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -60,7 +60,6 @@
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
- from diffusers.utils.torch_utils import randn_tensor


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1157,14 +1156,8 @@ def compute_text_embeddings(prompt):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)

- if unet.config.in_channels > channels:
-     needed_additional_channels = unet.config.in_channels - channels
-     additional_latents = randn_tensor(
-         (bsz, needed_additional_channels, height, width),
-         device=noisy_model_input.device,
-         dtype=noisy_model_input.dtype,
-     )
-     noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+ if unet.config.in_channels == channels * 2:
+     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

if args.class_labels_conditioning == "timesteps":
    class_labels = timesteps