
Commit 0fc2fb7

dreambooth upscaling fix added latents (#3659)
1 parent: 523a50a

4 files changed, +32 -30 lines changed


docs/source/en/training/dreambooth.mdx

Lines changed: 14 additions & 6 deletions
@@ -540,10 +540,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I
 For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
 LoRA finetuning stage II.
 
-For finegrained detail like faces, we find that lower learning rates work best.
+For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.
 
 For stage II, we find that lower learning rates are also needed.
 
+We found experimentally that the DDPM scheduler with its default, larger number of denoising steps sometimes works better than the DPM Solver scheduler
+used in the training scripts.
+
 ### Stage II additional validation images
 
 The stage II validation requires images to upscale, we can download a downsized version of the training set:
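
The note added above contrasts the DDPM scheduler with the DPM Solver scheduler used in the training scripts. A minimal sketch of swapping in the DDPM scheduler when upscaling with a finetuned stage II checkpoint (the model path is a placeholder; this snippet is illustrative rather than taken from the scripts):

```python
from diffusers import DDPMScheduler, IFSuperResolutionPipeline

# Placeholder path to a DreamBooth-finetuned IF stage II upscaler.
pipe = IFSuperResolutionPipeline.from_pretrained("path/to/finetuned-if-stage-ii")

# Swap the solver for DDPM; its larger default number of denoising steps
# sometimes recovers finer detail (e.g. faces) than DPM Solver.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
```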
@@ -631,7 +634,8 @@ with a T5 loaded from the original model.
 
 `use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.
 
-`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.
+`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that the
+learning rate can likely be increased with larger batch sizes.
 
 Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
 
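The `use_8bit_adam` flag mentioned above switches the optimizer to bitsandbytes' 8-bit AdamW, which shrinks the optimizer state enough to train the full XL IF model. A rough sketch of that construction, assuming bitsandbytes is installed (the parameter list below is a stand-in for the UNet parameters the script actually passes):

```python
import bitsandbytes as bnb
import torch

# Stand-in parameters; the training script passes the UNet parameters here.
params = [torch.nn.Parameter(torch.zeros(16))]

# 8-bit AdamW keeps optimizer state in 8 bits, which is what lets full
# finetuning of the XL IF model fit in ~48 GB VRAM at batch size 4.
optimizer = bnb.optim.AdamW8bit(params, lr=1e-7)
```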
@@ -656,18 +660,22 @@ accelerate launch train_dreambooth.py \
   --text_encoder_use_attention_mask \
   --tokenizer_max_length 77 \
   --pre_compute_text_embeddings \
-  --use_8bit_adam \ #
+  --use_8bit_adam \
   --set_grads_to_none \
   --skip_save_text_encoder \
   --push_to_hub
 ```
 
 ### IF Stage II Full Dreambooth
 
-`--learning_rate=1e-8`: Even lower learning rate.
+`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as
+1e-8.
 
 `--resolution=256`: The upscaler expects higher resolution inputs
 
+`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II particularly with
+faces required large effective batch sizes.
+
 ```sh
 export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
 export INSTANCE_DIR="dog"
@@ -682,8 +690,8 @@ accelerate launch train_dreambooth.py \
   --instance_prompt="a sks dog" \
   --resolution=256 \
   --train_batch_size=2 \
-  --gradient_accumulation_steps=2 \
-  --learning_rate=1e-8 \
+  --gradient_accumulation_steps=6 \
+  --learning_rate=5e-6 \
   --max_train_steps=2000 \
   --validation_prompt="a sks dog" \
   --validation_steps=150 \
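
For the "large effective batch sizes" mentioned above: the effective batch size is the per-device batch size times the gradient accumulation steps (times the number of processes `accelerate` launches). A small arithmetic sketch for the updated stage II command, assuming a single GPU:

```python
# Effective batch size implied by the updated stage II command above,
# assuming a single process; multiply by the number of GPUs otherwise.
train_batch_size = 2
gradient_accumulation_steps = 6
num_processes = 1  # assumption: single-GPU launch

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 12
```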

examples/dreambooth/README.md

Lines changed: 14 additions & 6 deletions
@@ -574,10 +574,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I
 For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
 LoRA finetuning stage II.
 
-For finegrained detail like faces, we find that lower learning rates work best.
+For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.
 
 For stage II, we find that lower learning rates are also needed.
 
+We found experimentally that the DDPM scheduler with its default, larger number of denoising steps sometimes works better than the DPM Solver scheduler
+used in the training scripts.
+
 ### Stage II additional validation images
 
 The stage II validation requires images to upscale, we can download a downsized version of the training set:
@@ -665,7 +668,8 @@ with a T5 loaded from the original model.
 
 `use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.
 
-`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.
+`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that the
+learning rate can likely be increased with larger batch sizes.
 
 Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
 
@@ -690,18 +694,22 @@ accelerate launch train_dreambooth.py \
   --text_encoder_use_attention_mask \
   --tokenizer_max_length 77 \
   --pre_compute_text_embeddings \
-  --use_8bit_adam \ #
+  --use_8bit_adam \
   --set_grads_to_none \
   --skip_save_text_encoder \
   --push_to_hub
 ```
 
 ### IF Stage II Full Dreambooth
 
-`--learning_rate=1e-8`: Even lower learning rate.
+`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as
+1e-8.
 
 `--resolution=256`: The upscaler expects higher resolution inputs
 
+`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II particularly with
+faces required large effective batch sizes.
+
 ```sh
 export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
 export INSTANCE_DIR="dog"
@@ -716,8 +724,8 @@ accelerate launch train_dreambooth.py \
   --instance_prompt="a sks dog" \
   --resolution=256 \
   --train_batch_size=2 \
-  --gradient_accumulation_steps=2 \
-  --learning_rate=1e-8 \
+  --gradient_accumulation_steps=6 \
+  --learning_rate=5e-6 \
   --max_train_steps=2000 \
   --validation_prompt="a sks dog" \
   --validation_steps=150 \

examples/dreambooth/train_dreambooth.py

Lines changed: 2 additions & 9 deletions
@@ -52,7 +52,6 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.torch_utils import randn_tensor
 
 
 if is_wandb_available():
@@ -1212,14 +1211,8 @@ def compute_text_embeddings(prompt):
                     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                 )
 
-                if unet.config.in_channels > channels:
-                    needed_additional_channels = unet.config.in_channels - channels
-                    additional_latents = randn_tensor(
-                        (bsz, needed_additional_channels, height, width),
-                        device=noisy_model_input.device,
-                        dtype=noisy_model_input.dtype,
-                    )
-                    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+                if unet.config.in_channels == channels * 2:
+                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
 
                 if args.class_labels_conditioning == "timesteps":
                     class_labels = timesteps
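
A self-contained sketch of what the new branch does to tensor shapes; the concrete numbers are illustrative stand-ins for what the script derives from the batch and from `unet.config.in_channels`:

```python
import torch

# Illustrative shapes for an IF stage II upscaler whose UNet expects twice
# as many input channels as the noisy sample provides (e.g. 6 vs. 3).
bsz, channels, height, width = 4, 3, 256, 256
unet_in_channels = channels * 2  # stand-in for unet.config.in_channels

noisy_model_input = torch.randn(bsz, channels, height, width)

# The old code padded with freshly sampled random latents; the fix instead
# duplicates the noisy sample along the channel dimension.
if unet_in_channels == channels * 2:
    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

assert noisy_model_input.shape == (bsz, unet_in_channels, height, width)
```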

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 2 additions & 9 deletions
@@ -60,7 +60,6 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.torch_utils import randn_tensor
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1157,14 +1156,8 @@ def compute_text_embeddings(prompt):
                     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                 )
 
-                if unet.config.in_channels > channels:
-                    needed_additional_channels = unet.config.in_channels - channels
-                    additional_latents = randn_tensor(
-                        (bsz, needed_additional_channels, height, width),
-                        device=noisy_model_input.device,
-                        dtype=noisy_model_input.dtype,
-                    )
-                    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+                if unet.config.in_channels == channels * 2:
+                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
 
                 if args.class_labels_conditioning == "timesteps":
                     class_labels = timesteps
