diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 7094570ad5eb..62ab1b88c0cb 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -160,6 +160,39 @@ accelerate launch train_dreambooth.py \
   --mixed_precision=fp16
 ```
 
+### Fine-tune the text encoder with the UNet
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option enabled, training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_text_encoder \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --use_8bit_adam \
+  --gradient_checkpointing \
+  --learning_rate=2e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
+  --max_train_steps=800
+```
+
 ## Inference
 
 Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index ca39aeff236b..40d5bca45de9 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1,4 +1,5 @@
 import argparse
+import itertools
 import math
 import os
 from pathlib import Path
@@ -100,6 +101,7 @@ def parse_args():
     parser.add_argument(
         "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
     )
+    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -320,6 +322,15 @@ def main():
         logging_dir=logging_dir,
     )
 
+    # Currently, it's not possible to do gradient accumulation when training two models with `accelerate.accumulate`.
+    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
+
     if args.seed is not None:
         set_seed(args.seed)
 
@@ -385,8 +396,14 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
+    vae.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder.gradient_checkpointing_enable()
 
     if args.scale_lr:
         args.learning_rate = (
@@ -406,8 +423,11 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
+    params_to_optimize = (
+        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+    )
     optimizer = optimizer_class(
-        unet.parameters(),  # only optimize unet
+        params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -467,9 +487,14 @@ def collate_fn(examples):
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
 
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
@@ -480,8 +505,9 @@ def collate_fn(examples):
     # Move text_encode and vae to gpu.
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
+    if not args.train_text_encoder:
+        text_encoder.to(accelerator.device, dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -516,9 +542,8 @@ def collate_fn(examples):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
-                with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                    latents = latents * 0.18215
+                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
@@ -532,8 +557,7 @@ def collate_fn(examples):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                with torch.no_grad():
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -556,7 +580,12 @@ def collate_fn(examples):
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                    params_to_clip = (
+                        itertools.chain(unet.parameters(), text_encoder.parameters())
+                        if args.train_text_encoder
+                        else unet.parameters()
+                    )
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -578,7 +607,9 @@ def collate_fn(examples):
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         pipeline = StableDiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet)
+            args.pretrained_model_name_or_path,
+            unet=accelerator.unwrap_model(unet),
+            text_encoder=accelerator.unwrap_model(text_encoder),
         )
         pipeline.save_pretrained(args.output_dir)
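
For completeness, here is a minimal inference sketch matching the README's inference note above. It is not part of the diff: the model directory, prompt, and generation parameters are illustrative assumptions. It simply loads the pipeline saved by `train_dreambooth.py`, which (per the `from_pretrained(..., text_encoder=...)` change above) now bundles the fine-tuned text encoder alongside the UNet:

```python
import torch
from diffusers import StableDiffusionPipeline

# Directory passed as --output_dir during training (illustrative placeholder).
model_dir = "path-to-save-model"

# torch_dtype=torch.float16 is optional; it only reduces memory use at inference.
pipe = StableDiffusionPipeline.from_pretrained(model_dir, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Include the rare identifier used during training ("sks" in the README example).
prompt = "a photo of sks dog in a bucket"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("sks-dog.png")
```

Because the saved pipeline includes both trained components, no extra flags are needed at inference time; loading the output directory picks up the fine-tuned `unet` and `text_encoder` together.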