diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx
index 908355e496dc..88ded0e009dc 100644
--- a/docs/source/en/training/dreambooth.mdx
+++ b/docs/source/en/training/dreambooth.mdx
@@ -60,7 +60,20 @@ DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit.
-Let's try DreamBooth with a [few images of a dog](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ); download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
+Let's try DreamBooth with a
+[few images of a dog](https://huggingface.co/datasets/diffusers/dog-example);
+download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./path_to_training_images"
+snapshot_download(
+    "diffusers/dog-example",
+    local_dir=local_dir, repo_type="dataset",
+    ignore_patterns=".gitattributes",
+)
+```

 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx
index 1c72fbbc8d58..ac2311df9f1e 100644
--- a/docs/source/en/training/lora.mdx
+++ b/docs/source/en/training/lora.mdx
@@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License.
-Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`].
+Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionModel`]. We also
+support LoRA fine-tuning of the text encoder for DreamBooth in a limited capacity. For more details on how
+this is implemented, refer to the discussion in [this PR](https://github.com/huggingface/diffusers/pull/2918).
@@ -175,6 +177,11 @@ accelerate launch train_dreambooth_lora.py \
   --push_to_hub
 ```

+It's also possible to fine-tune the text encoder with LoRA. In most cases this leads
+to better results at the cost of a slight increase in compute. To fine-tune the text encoder with LoRA,
+pass the `--train_text_encoder` argument when launching the `train_dreambooth_lora.py` script.
+
+
 ### Inference[[dreambooth-inference]]

 Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index d53f17114404..8447c7560720 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -45,15 +45,28 @@ write_basic_config()
 ### Dog toy example

-Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
-And launch the training using +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +And launch the training using: **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export OUTPUT_DIR="path-to-save-model" accelerate launch train_dreambooth.py \ @@ -77,7 +90,7 @@ According to the paper, it's recommended to generate `num_epochs * num_samples` ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" @@ -108,7 +121,7 @@ To install `bitandbytes` please refer to this [readme](https://github.com/TimDet ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" @@ -141,7 +154,7 @@ It is possible to run dreambooth on a 12GB GPU by using the following optimizati ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" @@ -185,7 +198,7 @@ does not seem to be compatible with DeepSpeed at the moment. ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" @@ -217,7 +230,7 @@ ___Note: Training text encoder requires more memory, with this option the traini ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" @@ -300,7 +313,7 @@ Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https: ```bash export MODEL_NAME="runwayml/stable-diffusion-v1-5" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export OUTPUT_DIR="path-to-save-model" ``` @@ -342,6 +355,12 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5). You can use the `Step` slider to see how the model learned the features of our subject while the model trained. +Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we +enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918). + +With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth). + + ### Inference After training, LoRA weights can be loaded very easily into the original pipeline. 
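For a concrete picture of that loading step, here is a minimal sketch that mirrors the final-inference block of `train_dreambooth_lora.py` later in this diff. The base model matches the `MODEL_NAME` used for LoRA training above; `path-to-save-model` and the prompt are placeholders to replace with your own `--output_dir` and instance prompt:

```python
import torch

from diffusers import DiffusionPipeline

# Load the frozen base model the LoRA layers were trained on.
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

# Load the LoRA attention processors saved by the training script on top of the frozen base weights.
pipeline.load_attn_procs("path-to-save-model")

image = pipeline("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog-bucket.png")
```

When `--train_text_encoder` was used during training, the weights saved to that directory also include the text encoder LoRA layers.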
First, you need to @@ -386,7 +405,7 @@ pip install -U -r requirements_flax.txt ```bash export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export OUTPUT_DIR="path-to-save-model" python train_dreambooth_flax.py \ @@ -405,7 +424,7 @@ python train_dreambooth_flax.py \ ```bash export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" @@ -429,7 +448,7 @@ python train_dreambooth_flax.py \ ```bash export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" -export INSTANCE_DIR="path-to-instance-images" +export INSTANCE_DIR="dog" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index d360939c8c0c..1b75402c3550 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -15,6 +15,7 @@ import argparse import hashlib +import itertools import logging import math import os @@ -43,12 +44,13 @@ DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, + StableDiffusionPipeline, UNet2DConditionModel, ) -from diffusers.loaders import AttnProcsLayers +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -58,7 +60,7 @@ logger = get_logger(__name__) -def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None): +def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -83,6 +85,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_ These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n {img_str} + +LoRA for the text encoder was enabled: {train_text_encoder}. """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) @@ -219,6 +223,11 @@ def parse_args(input_args=None): " cropped. The images will be resized to the resolution first before cropping." ), ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -547,7 +556,13 @@ def main(args): # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. 
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -691,7 +706,7 @@ def main(args): # => 32 layers # Set correct lora layers - lora_attn_procs = {} + unet_lora_attn_procs = {} for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): @@ -703,12 +718,33 @@ def main(args): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - - unet.set_attn_processor(lora_attn_procs) - lora_layers = AttnProcsLayers(unet.attn_processors) + unet_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) - accelerator.register_for_checkpointing(lora_layers) + unet.set_attn_processor(unet_lora_attn_procs) + unet_lora_layers = AttnProcsLayers(unet.attn_processors) + accelerator.register_for_checkpointing(unet_lora_layers) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, + # we first load a dummy pipeline with the text encoder and then do the monkey-patching. + text_encoder_lora_layers = None + if args.train_text_encoder: + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_features, cross_attention_dim=None + ) + text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) + temp_pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, text_encoder=text_encoder + ) + temp_pipeline._modify_text_encoder(text_lora_attn_procs) + text_encoder = temp_pipeline.text_encoder + accelerator.register_for_checkpointing(unet_lora_layers) + del temp_pipeline if args.scale_lr: args.learning_rate = ( @@ -739,8 +775,13 @@ def main(args): optimizer_class = torch.optim.AdamW # Optimizer creation + params_to_optimize = ( + itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + if args.train_text_encoder + else unet_lora_layers.parameters() + ) optimizer = optimizer_class( - lora_layers.parameters(), + params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -784,9 +825,14 @@ def main(args): ) # Prepare everything with our `accelerator`. 
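As an aside, for intuition about what the `LoRAAttnProcessor` layers configured above contribute — to the UNet attention blocks and, when `--train_text_encoder` is passed, to the text encoder modules matched by `TEXT_ENCODER_TARGET_MODULES` — here is a standalone sketch of the low-rank update that LoRA adds on top of a frozen linear projection. The sizes and rank are illustrative, not the script's actual values:

```python
import torch

hidden_size, rank = 768, 4
W = torch.randn(hidden_size, hidden_size)  # frozen pretrained projection weight (left untouched)
A = torch.randn(rank, hidden_size) * 0.01  # LoRA down-projection (trainable)
B = torch.zeros(hidden_size, rank)         # LoRA up-projection (trainable, zero-initialized so the update starts as a no-op)

x = torch.randn(1, hidden_size)            # a single token embedding
base_out = x @ W.T                         # what the frozen layer computes on its own
lora_out = base_out + (x @ A.T) @ B.T      # frozen output plus the rank-4 correction
```

Only the small `A` and `B` matrices require gradients, which is why the optimizer above receives just `unet_lora_layers.parameters()` and, when the text encoder is trained, `text_encoder_lora_layers.parameters()`.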
- lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - lora_layers, optimizer, train_dataloader, lr_scheduler - ) + if args.train_text_encoder: + unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler + ) + else: + unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -845,6 +891,8 @@ def main(args): for epoch in range(first_epoch, args.num_train_epochs): unet.train() + if args.train_text_encoder: + text_encoder.train() for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: @@ -900,7 +948,11 @@ def main(args): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = lora_layers.parameters() + params_to_clip = ( + itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + if args.train_text_encoder + else unet_lora_layers.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -914,7 +966,14 @@ def main(args): if global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) + # We combine the text encoder and UNet LoRA parameters with a simple + # custom logic. `accelerator.save_state()` won't know that. So, + # use `LoraLoaderMixin.save_lora_weights()`. 
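As background for the `save_lora_weights()` call that follows — and as the new tests at the end of this diff assert — the serialized file is a single `pytorch_lora_weights.bin` whose keys are namespaced per component, so a saved checkpoint can be inspected with a few lines of plain PyTorch (the directory is a placeholder for whatever was passed as `--output_dir`):

```python
import torch

lora_state_dict = torch.load("path-to-save-model/pytorch_lora_weights.bin", map_location="cpu")

unet_keys = [k for k in lora_state_dict if k.startswith("unet")]
text_encoder_keys = [k for k in lora_state_dict if k.startswith("text_encoder")]  # empty unless --train_text_encoder was used

print(f"{len(unet_keys)} UNet LoRA tensors, {len(text_encoder_keys)} text encoder LoRA tensors")
```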
+ LoraLoaderMixin.save_lora_weights( + save_directory=save_path, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -970,7 +1029,12 @@ def main(args): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unet.to(torch.float32) - unet.save_attn_procs(args.output_dir) + text_encoder = text_encoder.to(torch.float32) + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) # Final inference # Load previous pipeline @@ -981,7 +1045,7 @@ def main(args): pipeline = pipeline.to(accelerator.device) # load attention processors - pipeline.unet.load_attn_procs(args.output_dir) + pipeline.load_attn_procs(args.output_dir) # run inference if args.validation_prompt and args.num_validation_images > 0: @@ -1010,6 +1074,7 @@ def main(args): repo_id, images=images, base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, prompt=args.instance_prompt, repo_folder=args.output_dir, ) diff --git a/examples/test_examples.py b/examples/test_examples.py index a77fa4c7da23..238dc49d729f 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -23,6 +23,7 @@ import unittest from typing import List +import torch from accelerate.utils import write_basic_config from diffusers import DiffusionPipeline, UNet2DConditionModel @@ -221,6 +222,68 @@ def test_dreambooth_checkpointing(self): self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6"))) + def test_dreambooth_lora(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"unet"` in their names. 
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_unet) + + def test_dreambooth_lora_with_text_encoder(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --train_text_encoder + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + + # the names of the keys of the state dict should either start with `unet` + # or `text_encoder`. + lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + keys = lora_state_dict.keys() + is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys) + self.assertTrue(is_correct_naming) + def test_custom_diffusion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f"""
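To exercise just the two new example tests, one option is pytest's keyword filter; a minimal sketch, assuming `pytest` and the example requirements are installed and this is run from the repository root:

```python
import sys

import pytest

# Select test_dreambooth_lora and test_dreambooth_lora_with_text_encoder by name substring.
sys.exit(pytest.main(["-k", "dreambooth_lora", "examples/test_examples.py"]))
```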