     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, text_encoder_attn_modules
+from diffusers.loaders import (
+    LORA_WEIGHT_NAME,
+    TEXT_ENCODER_NAME,
+    UNET_NAME,
+    LoraLoaderMixin,
+    text_encoder_attn_modules,
+    text_encoder_lora_state_dict,
+)
 from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
@@ -832,6 +839,7 @@ def main(args):

     # Set correct lora layers
     unet_lora_attn_procs = {}
+    unet_lora_parameters = []
     for name, attn_processor in unet.attn_processors.items():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
         if name.startswith("mid_block"):
@@ -849,18 +857,17 @@ def main(args):
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        unet_lora_attn_procs[name] = module
+        unet_lora_parameters.extend(module.parameters())

     unet.set_attn_processor(unet_lora_attn_procs)
-    unet_lora_layers = AttnProcsLayers(unet.attn_processors)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
-    text_encoder_lora_layers = None
     if args.train_text_encoder:
         text_lora_attn_procs = {}
+        text_lora_parameters = []

         for name, module in text_encoder_attn_modules(text_encoder):
             if isinstance(text_encoder, CLIPTextModel):
@@ -872,9 +879,10 @@ def main(args):
             else:
                 raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")

-            text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, inner_dim=inner_dim)
+            module = LoRAAttnProcessor(hidden_size=hidden_size, inner_dim=inner_dim)
+            text_lora_attn_procs[name] = module
+            text_lora_parameters.extend(module.parameters())

-        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
         LoraLoaderMixin._modify_text_encoder(text_lora_attn_procs, text_encoder)

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
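
The save hook in the next hunk reads the LoRA weights straight off the models, via `model.attn_processors_state_dict` for the UNet and `text_encoder_lora_state_dict(model)` for the text encoder, instead of keeping separate `AttnProcsLayers` wrappers around. The UNet-side attribute is not defined in this diff; a minimal sketch of an equivalent helper, assuming only the standard `unet.attn_processors` mapping and `named_parameters()` (the helper name below is illustrative, not part of the change), would be:

```python
from typing import Dict

import torch


def attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
    # Flatten every attention processor's (LoRA) parameters into one dict,
    # keyed "<processor name>.<parameter name>" so the keys line up with
    # what `unet.attn_processors` / `save_lora_weights` expect downstream.
    flat = {}
    for proc_name, proc in unet.attn_processors.items():
        for param_name, param in proc.named_parameters():
            flat[f"{proc_name}.{param_name}"] = param
    return flat
```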
@@ -884,23 +892,13 @@ def save_model_hook(models, weights, output_dir):
         unet_lora_layers_to_save = None
         text_encoder_lora_layers_to_save = None

-        if args.train_text_encoder:
-            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-            unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
-
         for model in models:
-            state_dict = model.state_dict()
-
-            if (
-                text_encoder_lora_layers is not None
-                and text_encoder_keys is not None
-                and state_dict.keys() == text_encoder_keys
-            ):
-                # text encoder
-                text_encoder_lora_layers_to_save = state_dict
-            elif state_dict.keys() == unet_keys:
-                # unet
-                unet_lora_layers_to_save = state_dict
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = model.attn_processors_state_dict
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")

             # make sure to pop weight so that corresponding model is not saved again
             weights.pop()
@@ -912,27 +910,23 @@ def save_model_hook(models, weights, output_dir):
         )

     def load_model_hook(models, input_dir):
-        # Note we DON'T pass the unet and text encoder here an purpose
-        # so that the we don't accidentally override the LoRA layers of
-        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
-        # with new torch.nn.Modules / weights. We simply use the pipeline class as
-        # an easy way to load the lora checkpoints
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            revision=args.revision,
-            torch_dtype=weight_dtype,
-        )
-        temp_pipeline.load_lora_weights(input_dir)
-
-        # load lora weights into models
-        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
-        if len(models) > 1:
-            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+        lora_weights = torch.load(os.path.join(input_dir, LORA_WEIGHT_NAME))
+        unet_weights = {}
+        text_encoder_weights = {}
+
+        for k, v in lora_weights.items():
+            model, *k = k.split(".")
+            k = ".".join(k)
+
+            if model == UNET_NAME:
+                unet_weights[k] = v
+            elif model == TEXT_ENCODER_NAME:
+                text_encoder_weights[k] = v
+            else:
+                raise ValueError(f"unknown model name {model}")

-        # delete temporary pipeline and pop models
-        del temp_pipeline
-        for _ in range(len(models)):
-            models.pop()
+        unet.load_state_dict(unet_weights, strict=False)
+        text_encoder.load_state_dict(text_encoder_weights, strict=False)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
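
The key handling in `load_model_hook` above assumes that `save_lora_weights` writes a single flat checkpoint (`LORA_WEIGHT_NAME`, i.e. `pytorch_lora_weights.bin` at the time of this change) whose keys are prefixed with `UNET_NAME` or `TEXT_ENCODER_NAME`. A small, self-contained illustration of the prefix splitting (the key below is made up for the example, not taken from a real checkpoint):

```python
# Made-up key, shaped like the entries save_lora_weights produces.
key = "unet.mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.up.weight"

prefix, *rest = key.split(".")  # same idiom as in load_model_hook
sub_key = ".".join(rest)

print(prefix)   # "unet" -> routed into unet_weights
print(sub_key)  # remainder, loaded via unet.load_state_dict(..., strict=False)
```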
@@ -962,9 +956,9 @@ def load_model_hook(models, input_dir):

     # Optimizer creation
     params_to_optimize = (
-        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        itertools.chain(unet_lora_parameters, text_lora_parameters)
         if args.train_text_encoder
-        else unet_lora_layers.parameters()
+        else unet_lora_parameters
     )
     optimizer = optimizer_class(
         params_to_optimize,
@@ -1053,12 +1047,12 @@ def compute_text_embeddings(prompt):

     # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
-        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
         )
     else:
-        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
         )

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
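
Because the full `unet` and `text_encoder` (rather than `AttnProcsLayers` wrappers) now go through `accelerator.prepare`, only the freshly added LoRA processors should carry trainable parameters, assuming the base weights were frozen with `requires_grad_(False)` earlier in the script as in the stock example. A quick optional check, purely illustrative:

```python
# Optional sanity check (assumes the frozen-base setup described above).
trainable = [n for n, p in unet.named_parameters() if p.requires_grad]
frozen = [n for n, p in unet.named_parameters() if not p.requires_grad]
print(f"trainable LoRA tensors: {len(trainable)}, frozen base tensors: {len(frozen)}")
```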
@@ -1207,9 +1201,9 @@ def compute_text_embeddings(prompt):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     params_to_clip = (
-                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        itertools.chain(unet_lora_parameters, text_lora_parameters)
                         if args.train_text_encoder
-                        else unet_lora_layers.parameters()
+                        else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
@@ -1309,12 +1303,14 @@ def compute_text_embeddings(prompt):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+        unet_lora_layers = unet.attn_processors_state_dict

-        if text_encoder is not None:
+        if text_encoder is not None and args.train_text_encoder:
+            text_encoder = accelerator.unwrap_model(text_encoder)
             text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)

         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
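
After training, the checkpoint written by `LoraLoaderMixin.save_lora_weights` can be loaded back for inference with `load_lora_weights`; a minimal sketch, where the base model ID, output path, and prompt are placeholders:

```python
import torch
from diffusers import DiffusionPipeline

# Placeholders: use the base model you trained against and your --output_dir.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.load_lora_weights("path/to/output_dir")
pipe.to("cuda")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("lora_sample.png")
```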