 import argparse
 import copy
 import gc
-import itertools
 import logging
 import math
 import os
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.attention_processor import (
-    AttnAddedKVProcessor,
-    AttnAddedKVProcessor2_0,
-    SlicedAttnAddedKVProcessor,
-)
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
@@ -864,79 +858,19 @@ def main(args):
             text_encoder.gradient_checkpointing_enable()
 
     # now we will add new LoRA weights to the attention layers
-    # It's important to realize here how many attention weights will be added and of which sizes
-    # The sizes of the attention layers consist only of two different variables:
-    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
-    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
-
-    # Let's first see how many attention processors we will have to set.
-    # For Stable Diffusion, it should be equal to:
-    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
-    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
-    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
-    # => 32 layers
-
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
-
-        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
-            attn_module.add_k_proj.set_lora_layer(
-                LoRALinearLayer(
-                    in_features=attn_module.add_k_proj.in_features,
-                    out_features=attn_module.add_k_proj.out_features,
-                    rank=args.rank,
-                )
-            )
-            attn_module.add_v_proj.set_lora_layer(
-                LoRALinearLayer(
-                    in_features=attn_module.add_v_proj.in_features,
-                    out_features=attn_module.add_v_proj.out_features,
-                    rank=args.rank,
-                )
-            )
-            unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
-            unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+    )
+    unet.add_adapter(unet_lora_config)
 
-    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks.
+    # The text encoder comes from 🤗 transformers, so we also attach adapters to it.
     if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
+        text_lora_config = LoraConfig(
+            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+        )
+        text_encoder.add_adapter(text_lora_config)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
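
The hunk above is the core of the change: instead of hand-wiring a LoRALinearLayer into every attention projection, a single LoraConfig names the sub-modules to adapt and add_adapter injects rank-r LoRA matrices into each match. The sketch below is not part of the commit; it uses a deliberately tiny, made-up UNet configuration (and rank 4 instead of args.rank) purely so it runs quickly, while the real script uses the pretrained Stable Diffusion UNet.

# Illustrative only: toy UNet config, not Stable Diffusion's.
from diffusers import UNet2DConditionModel
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
unet.requires_grad_(False)  # the training script freezes the base weights earlier (not shown in this diff)

unet_lora_config = LoraConfig(
    r=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
unet.add_adapter(unet_lora_config)

# Only the injected lora_A / lora_B matrices require gradients, which is what the
# `filter(lambda p: p.requires_grad, unet.parameters())` in the optimizer hunk below relies on.
trainable = [name for name, param in unet.named_parameters() if param.requires_grad]
assert trainable and all("lora_" in name for name in trainable)

# `get_peft_model_state_dict` extracts just those adapter tensors; this is what the new
# `save_model_hook` and the final `save_lora_weights` call serialize.
lora_state_dict = get_peft_model_state_dict(unet)
print(f"{len(lora_state_dict)} LoRA tensors vs. {len(unet.state_dict())} tensors in the full UNet")
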
@@ -948,9 +882,9 @@ def save_model_hook(models, weights, output_dir):
 
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_lora_state_dict(model)
+                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-                    text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1010,11 +944,10 @@ def load_model_hook(models, input_dir):
         optimizer_class = torch.optim.AdamW
 
     # Optimizer creation
-    params_to_optimize = (
-        itertools.chain(unet_lora_parameters, text_lora_parameters)
-        if args.train_text_encoder
-        else unet_lora_parameters
-    )
+    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+    if args.train_text_encoder:
+        params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
+
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
@@ -1257,12 +1190,7 @@ def compute_text_embeddings(prompt):
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet_lora_parameters, text_lora_parameters)
-                        if args.train_text_encoder
-                        else unet_lora_parameters
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -1385,19 +1313,19 @@ def compute_text_embeddings(prompt):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
 
-        if text_encoder is not None and args.train_text_encoder:
+        unet_lora_state_dict = get_peft_model_state_dict(unet)
+
+        if args.train_text_encoder:
             text_encoder = accelerator.unwrap_model(text_encoder)
-            text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+            text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
         else:
-            text_encoder_lora_layers = None
+            text_encoder_state_dict = None
 
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
-            unet_lora_layers=unet_lora_layers,
-            text_encoder_lora_layers=text_encoder_lora_layers,
+            unet_lora_layers=unet_lora_state_dict,
+            text_encoder_lora_layers=text_encoder_state_dict,
         )
 
     # Final inference
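
After save_lora_weights, the adapter weights for the UNet (and, if trained, the text encoder) live in args.output_dir, and the final-inference block can reload them through the pipeline-level LoRA loader. A rough usage sketch, not taken from the commit; the base checkpoint name, output path, and prompt below are placeholders.

# Placeholder values throughout: swap in your own base model, output directory, and prompt.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Loads the file written by `LoraLoaderMixin.save_lora_weights` (UNet and, if present,
# text-encoder adapters) and applies it to the pipeline's models.
pipeline.load_lora_weights("path/to/output_dir")

image = pipeline("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")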