diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index ff687db5c18b..6fbac07158c1 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -113,6 +113,7 @@ jobs: - name: Run example PyTorch CPU tests if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | + python -m pip install peft python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ --make-reports=tests_${{ matrix.config.report }} \ examples diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 0579e337939d..972fe6e8cffb 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -44,6 +44,7 @@ write_basic_config() ``` When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. ### Dog toy example diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md index d78d1ef5d2dd..66232d3063f5 100644 --- a/examples/dreambooth/README_sdxl.md +++ b/examples/dreambooth/README_sdxl.md @@ -47,6 +47,7 @@ write_basic_config() ``` When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. ### Dog toy example diff --git a/examples/dreambooth/requirements.txt b/examples/dreambooth/requirements.txt index 7a612982f4ab..3f86855e1d1e 100644 --- a/examples/dreambooth/requirements.txt +++ b/examples/dreambooth/requirements.txt @@ -4,3 +4,4 @@ transformers>=4.25.1 ftfy tensorboard Jinja2 +peft==0.7.0 \ No newline at end of file diff --git a/examples/dreambooth/requirements_sdxl.txt b/examples/dreambooth/requirements_sdxl.txt index 7a612982f4ab..3f86855e1d1e 100644 --- a/examples/dreambooth/requirements_sdxl.txt +++ b/examples/dreambooth/requirements_sdxl.txt @@ -4,3 +4,4 @@ transformers>=4.25.1 ftfy tensorboard Jinja2 +peft==0.7.0 \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index b2905a05cef3..60213dd75685 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -16,7 +16,6 @@ import argparse import copy import gc -import itertools import logging import math import os @@ -35,6 +34,8 @@ from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset @@ -52,14 +53,7 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.attention_processor import ( - AttnAddedKVProcessor, - AttnAddedKVProcessor2_0, - SlicedAttnAddedKVProcessor, -) -from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler -from diffusers.training_utils import unet_lora_state_dict from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -864,79 +858,19 @@ def main(args): text_encoder.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers - # It's important to realize here how many attention weights will be added and of which sizes - # The sizes of the attention layers consist only of two different variables: - # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. - # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. - - # Let's first see how many attention processors we will have to set. - # For Stable Diffusion, it should be equal to: - # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 - # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 - # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18 - # => 32 layers - - # Set correct lora layers - unet_lora_parameters = [] - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=args.rank, - ) - ) - - # Accumulate the LoRA params to optimize. - unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) - - if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): - attn_module.add_k_proj.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.add_k_proj.in_features, - out_features=attn_module.add_k_proj.out_features, - rank=args.rank, - ) - ) - attn_module.add_v_proj.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.add_v_proj.in_features, - out_features=attn_module.add_v_proj.out_features, - rank=args.rank, - ) - ) - unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) + unet_lora_config = LoraConfig( + r=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], + ) + unet.add_adapter(unet_lora_config) - # The text encoder comes from 🤗 transformers, so we cannot directly modify it. - # So, instead, we monkey-patch the forward calls of its attention-blocks. + # The text encoder comes from 🤗 transformers, we will also attach adapters to it. if args.train_text_encoder: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank) + text_lora_config = LoraConfig( + r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] + ) + text_encoder.add_adapter(text_lora_config) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -948,9 +882,9 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): - unet_lora_layers_to_save = unet_lora_state_dict(model) + unet_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): - text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1010,11 +944,10 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # Optimizer creation - params_to_optimize = ( - itertools.chain(unet_lora_parameters, text_lora_parameters) - if args.train_text_encoder - else unet_lora_parameters - ) + params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) + if args.train_text_encoder: + params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters())) + optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -1257,12 +1190,7 @@ def compute_text_embeddings(prompt): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(unet_lora_parameters, text_lora_parameters) - if args.train_text_encoder - else unet_lora_parameters - ) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -1385,19 +1313,19 @@ def compute_text_embeddings(prompt): if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = unet_lora_state_dict(unet) - if text_encoder is not None and args.train_text_encoder: + unet_lora_state_dict = get_peft_model_state_dict(unet) + + if args.train_text_encoder: text_encoder = accelerator.unwrap_model(text_encoder) - text_encoder = text_encoder.to(torch.float32) - text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder) + text_encoder_state_dict = get_peft_model_state_dict(text_encoder) else: - text_encoder_lora_layers = None + text_encoder_state_dict = None LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, - unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers, + unet_lora_layers=unet_lora_state_dict, + text_encoder_lora_layers=text_encoder_state_dict, ) # Final inference diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 98b5f3bc5694..c8a9a6ad4812 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -34,6 +34,8 @@ from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset @@ -50,9 +52,8 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr, unet_lora_state_dict +from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -1009,54 +1010,19 @@ def main(args): text_encoder_two.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers - # Set correct lora layers - unet_lora_parameters = [] - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=args.rank, - ) - ) - - # Accumulate the LoRA params to optimize. - unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + unet_lora_config = LoraConfig( + r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + ) + unet.add_adapter(unet_lora_config) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder( - text_encoder_one, dtype=torch.float32, rank=args.rank - ) - text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder( - text_encoder_two, dtype=torch.float32, rank=args.rank + text_lora_config = LoraConfig( + r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] ) + text_encoder_one.add_adapter(text_lora_config) + text_encoder_two.add_adapter(text_lora_config) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -1069,11 +1035,11 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): - unet_lora_layers_to_save = unet_lora_state_dict(model) + unet_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1130,6 +1096,12 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) + + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + # Optimization parameters unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} if args.train_text_encoder: @@ -1194,26 +1166,10 @@ def load_model_hook(models, input_dir): optimizer_class = prodigyopt.Prodigy - if args.learning_rate <= 0.1: - logger.warn( - "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" - ) - if args.train_text_encoder and args.text_encoder_lr: - logger.warn( - f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" - f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " - f"When using prodigy only learning_rate is used as the initial learning rate." - ) - # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be - # --learning_rate - params_to_optimize[1]["lr"] = args.learning_rate - params_to_optimize[2]["lr"] = args.learning_rate - optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), - beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, decouple=args.prodigy_decouple, @@ -1659,13 +1615,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = unet_lora_state_dict(unet) + unet_lora_layers = get_peft_model_state_dict(unet) if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) - text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32)) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) text_encoder_two = accelerator.unwrap_model(text_encoder_two) - text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32)) + text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32)) else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 7b9f4013c746..e2cbaca2a9d8 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -32,6 +32,8 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e accelerate config ``` +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. + ### Pokemon example You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. diff --git a/examples/text_to_image/README_sdxl.md b/examples/text_to_image/README_sdxl.md index 75c9cb126472..1278185ddf1f 100644 --- a/examples/text_to_image/README_sdxl.md +++ b/examples/text_to_image/README_sdxl.md @@ -45,6 +45,7 @@ write_basic_config() ``` When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. ### Training diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt index 31b9026efdc2..0dd164fc2035 100644 --- a/examples/text_to_image/requirements.txt +++ b/examples/text_to_image/requirements.txt @@ -5,3 +5,4 @@ datasets ftfy tensorboard Jinja2 +peft==0.7.0 \ No newline at end of file diff --git a/examples/text_to_image/requirements_sdxl.txt b/examples/text_to_image/requirements_sdxl.txt index cdd3336e3617..64cbc9205fd0 100644 --- a/examples/text_to_image/requirements_sdxl.txt +++ b/examples/text_to_image/requirements_sdxl.txt @@ -5,3 +5,4 @@ ftfy tensorboard Jinja2 datasets +peft==0.7.0 \ No newline at end of file diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c030c59693c3..b63500f906a8 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -34,13 +34,14 @@ from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.models.lora import LoRALinearLayer +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available @@ -479,62 +480,20 @@ def main(): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + # Freeze the unet parameters before adding adapters + for param in unet.parameters(): + param.requires_grad_(False) + + unet_lora_config = LoraConfig( + r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + ) + # Move unet, vae and text_encoder to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) - # now we will add new LoRA weights to the attention layers - # It's important to realize here how many attention weights will be added and of which sizes - # The sizes of the attention layers consist only of two different variables: - # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. - # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. - - # Let's first see how many attention processors we will have to set. - # For Stable Diffusion, it should be equal to: - # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 - # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 - # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 - # => 32 layers - - # Set correct lora layers - unet_lora_parameters = [] - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank - ) - ) - - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=args.rank, - ) - ) - - # Accumulate the LoRA params to optimize. - unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + unet.add_adapter(unet_lora_config) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -549,6 +508,8 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: @@ -573,7 +534,7 @@ def main(): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - unet_lora_parameters, + lora_layers, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -700,8 +661,8 @@ def collate_fn(examples): ) # Prepare everything with our `accelerator`. - unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_parameters, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. @@ -833,7 +794,7 @@ def collate_fn(examples): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = unet_lora_parameters + params_to_clip = lora_layers accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -870,6 +831,15 @@ def collate_fn(examples): save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) + + unet_lora_state_dict = get_peft_model_state_dict(unet) + + StableDiffusionPipeline.save_lora_weights( + save_directory=save_path, + unet_lora_layers=unet_lora_state_dict, + safe_serialization=True, + ) + logger.info(f"Saved state to {save_path}") logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -926,7 +896,13 @@ def collate_fn(examples): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unet.to(torch.float32) - unet.save_attn_procs(args.output_dir) + + unet_lora_state_dict = get_peft_model_state_dict(unet) + StableDiffusionPipeline.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_state_dict, + safe_serialization=True, + ) if args.push_to_hub: save_model_card( diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index f1306ea5b9de..2e70c77e860e 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -16,7 +16,6 @@ """Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA.""" import argparse -import itertools import logging import math import os @@ -37,6 +36,8 @@ from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm @@ -50,7 +51,6 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available @@ -658,53 +658,20 @@ def main(args): # now we will add new LoRA weights to the attention layers # Set correct lora layers - unet_lora_parameters = [] - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=args.rank, - ) - ) + unet_lora_config = LoraConfig( + r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + ) - # Accumulate the LoRA params to optimize. - unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + unet.add_adapter(unet_lora_config) - # The text encoder comes from 🤗 transformers, so we cannot directly modify it. - # So, instead, we monkey-patch the forward calls of its attention-blocks. + # The text encoder comes from 🤗 transformers, we will also attach adapters to it. if args.train_text_encoder: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder( - text_encoder_one, dtype=torch.float32, rank=args.rank - ) - text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder( - text_encoder_two, dtype=torch.float32, rank=args.rank + text_lora_config = LoraConfig( + r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] ) + text_encoder_one.add_adapter(text_lora_config) + text_encoder_two.add_adapter(text_lora_config) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -717,11 +684,11 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): - unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + unet_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -792,11 +759,13 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # Optimizer creation - params_to_optimize = ( - itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) - if args.train_text_encoder - else unet_lora_parameters - ) + params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) + if args.train_text_encoder: + params_to_optimize = ( + params_to_optimize + + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -1128,12 +1097,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) - if args.train_text_encoder - else unet_lora_parameters - ) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -1229,20 +1193,21 @@ def compute_time_ids(original_size, crops_coords_top_left): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) - unet_lora_layers = unet_attn_processors_state_dict(unet) + unet_lora_state_dict = get_peft_model_state_dict(unet) if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) - text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one) text_encoder_two = accelerator.unwrap_model(text_encoder_two) - text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two) + + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one) + text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two) else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None StableDiffusionXLPipeline.save_lora_weights( save_directory=args.output_dir, - unet_lora_layers=unet_lora_layers, + unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, )