From 7e634a6d620ac6c2688cf3cd6f1d7cd1f5894fb1 Mon Sep 17 00:00:00 2001 From: William Berman Date: Fri, 2 Jun 2023 19:53:52 -0700 Subject: [PATCH] dreambooth upscaling fix added latents --- docs/source/en/training/dreambooth.mdx | 20 ++++++++++++++------ examples/dreambooth/README.md | 20 ++++++++++++++------ examples/dreambooth/train_dreambooth.py | 11 ++--------- examples/dreambooth/train_dreambooth_lora.py | 11 ++--------- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx index 9bba9df5bffc..c26762d4a75d 100644 --- a/docs/source/en/training/dreambooth.mdx +++ b/docs/source/en/training/dreambooth.mdx @@ -540,10 +540,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than LoRA finetuning stage II. -For finegrained detail like faces, we find that lower learning rates work best. +For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best. For stage II, we find that lower learning rates are also needed. +We found experimentally that the DDPM scheduler with the default larger number of denoising steps to sometimes work better than the DPM Solver scheduler +used in the training scripts. + ### Stage II additional validation images The stage II validation requires images to upscale, we can download a downsized version of the training set: @@ -631,7 +634,8 @@ with a T5 loaded from the original model. `use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam. -`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. +`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is +likely the learning rate can be increased with larger batch sizes. Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM. @@ -656,7 +660,7 @@ accelerate launch train_dreambooth.py \ --text_encoder_use_attention_mask \ --tokenizer_max_length 77 \ --pre_compute_text_embeddings \ - --use_8bit_adam \ # + --use_8bit_adam \ --set_grads_to_none \ --skip_save_text_encoder \ --push_to_hub @@ -664,10 +668,14 @@ accelerate launch train_dreambooth.py \ ### IF Stage II Full Dreambooth -`--learning_rate=1e-8`: Even lower learning rate. +`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as +1e-8. `--resolution=256`: The upscaler expects higher resolution inputs +`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II particularly with +faces required large effective batch sizes. + ```sh export MODEL_NAME="DeepFloyd/IF-II-L-v1.0" export INSTANCE_DIR="dog" @@ -682,8 +690,8 @@ accelerate launch train_dreambooth.py \ --instance_prompt="a sks dog" \ --resolution=256 \ --train_batch_size=2 \ - --gradient_accumulation_steps=2 \ - --learning_rate=1e-8 \ + --gradient_accumulation_steps=6 \ + --learning_rate=5e-6 \ --max_train_steps=2000 \ --validation_prompt="a sks dog" \ --validation_steps=150 \ diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 339152915adc..5813c42cd5d3 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -574,10 +574,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than LoRA finetuning stage II. -For finegrained detail like faces, we find that lower learning rates work best. +For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best. For stage II, we find that lower learning rates are also needed. +We found experimentally that the DDPM scheduler with the default larger number of denoising steps to sometimes work better than the DPM Solver scheduler +used in the training scripts. + ### Stage II additional validation images The stage II validation requires images to upscale, we can download a downsized version of the training set: @@ -665,7 +668,8 @@ with a T5 loaded from the original model. `use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam. -`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. +`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is +likely the learning rate can be increased with larger batch sizes. Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM. @@ -690,7 +694,7 @@ accelerate launch train_dreambooth.py \ --text_encoder_use_attention_mask \ --tokenizer_max_length 77 \ --pre_compute_text_embeddings \ - --use_8bit_adam \ # + --use_8bit_adam \ --set_grads_to_none \ --skip_save_text_encoder \ --push_to_hub @@ -698,10 +702,14 @@ accelerate launch train_dreambooth.py \ ### IF Stage II Full Dreambooth -`--learning_rate=1e-8`: Even lower learning rate. +`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as +1e-8. `--resolution=256`: The upscaler expects higher resolution inputs +`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II particularly with +faces required large effective batch sizes. + ```sh export MODEL_NAME="DeepFloyd/IF-II-L-v1.0" export INSTANCE_DIR="dog" @@ -716,8 +724,8 @@ accelerate launch train_dreambooth.py \ --instance_prompt="a sks dog" \ --resolution=256 \ --train_batch_size=2 \ - --gradient_accumulation_steps=2 \ - --learning_rate=1e-8 \ + --gradient_accumulation_steps=6 \ + --learning_rate=5e-6 \ --max_train_steps=2000 \ --validation_prompt="a sks dog" \ --validation_steps=150 \ diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index e4ab6b2ae014..ad03829fd1bc 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -52,7 +52,6 @@ from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.torch_utils import randn_tensor if is_wandb_available(): @@ -1212,14 +1211,8 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unet.config.in_channels > channels: - needed_additional_channels = unet.config.in_channels - channels - additional_latents = randn_tensor( - (bsz, needed_additional_channels, height, width), - device=noisy_model_input.device, - dtype=noisy_model_input.dtype, - ) - noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1) + if unet.config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": class_labels = timesteps diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 319348bd40bb..49aef1cc4a99 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -60,7 +60,6 @@ from diffusers.optimization import get_scheduler from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.torch_utils import randn_tensor # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -1157,14 +1156,8 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unet.config.in_channels > channels: - needed_additional_channels = unet.config.in_channels - channels - additional_latents = randn_tensor( - (bsz, needed_additional_channels, height, width), - device=noisy_model_input.device, - dtype=noisy_model_input.dtype, - ) - noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1) + if unet.config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": class_labels = timesteps