@@ -72,8 +72,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
 
-        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
-        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+        # .processor for unet, .self_attn for text encoder
+        self.split_keys = [".processor", ".self_attn"]
 
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
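Note: `split_keys` is used when flattened parameter keys are remapped back to `attn_processors` naming. A rough illustration of why a single `.self_attn` entry is enough for the text encoder (this sketches only the splitting idea, not the actual hook code):

    # Illustrative only: recover the processor prefix from a LoRA parameter key.
    key = "text_model.encoder.layers.0.self_attn.to_q_lora.up.weight"
    split_keys = [".processor", ".self_attn"]

    for sk in split_keys:
        if sk in key:
            processor_name = key.split(sk)[0] + sk
            break

    print(processor_name)  # text_model.encoder.layers.0.self_attn

One `.self_attn` entry covers all four projections that the old `.q_proj`/`.k_proj`/`.v_proj`/`.out_proj` entries matched individually.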
@@ -182,6 +182,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        network_alpha = kwargs.pop("network_alpha", None)
 
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -287,7 +290,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     attn_processor_class = LoRAAttnProcessor
 
                 attn_processors[key] = attn_processor_class(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    rank=rank,
+                    network_alpha=network_alpha,
                 )
                 attn_processors[key].load_state_dict(value_dict)
         elif is_custom_diffusion:
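Note: `network_alpha` follows the kohya-ss convention, where the effective LoRA update is scaled by `alpha / rank`. A minimal sketch of that scaling, under the assumption that this is how the processor consumes the value (illustrative names, not the actual `LoRAAttnProcessor` internals):

    import torch.nn as nn

    class LoRALinearSketch(nn.Module):
        # Hypothetical low-rank layer: scales its update by network_alpha / rank,
        # matching the kohya-ss --network_alpha convention.
        def __init__(self, in_features, out_features, rank=4, network_alpha=None):
            super().__init__()
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=False)
            self.network_alpha = network_alpha
            self.rank = rank

        def forward(self, hidden_states):
            update = self.up(self.down(hidden_states))
            if self.network_alpha is not None:
                update = update * (self.network_alpha / self.rank)
            return update

With the default `network_alpha=None`, nothing changes for LoRAs saved by the diffusers training scripts.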
@@ -774,6 +780,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
 
         <Tip warning={true}>
 
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
         This function is experimental and might change in the future.
 
         </Tip>
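Note: with this change an A1111/kohya-ss style `.safetensors` LoRA can be passed directly to `load_lora_weights`. A hypothetical usage sketch (the model id and file name are placeholders):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Keys starting with `lora_unet_` / `lora_te_` are detected and converted on the fly.
    pipe.load_lora_weights(".", weight_name="my_kohya_lora.safetensors")

    image = pipe("masterpiece, best quality, mountain landscape", num_inference_steps=25).images[0]

Only the attention-layer LoRA weights are converted and applied, which is what the "limited capacity" wording above refers to.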
@@ -898,6 +906,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
+        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
+        network_alpha = None
+        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
+            state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
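Note: the detection relies on the kohya-ss naming scheme, in which every tensor key starts with `lora_unet_` or `lora_te_` and each module contributes `lora_down.weight`, `lora_up.weight`, and `alpha` entries. Illustrative key names (typical of such checkpoints, not taken from any specific file):

    kohya_keys = [
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight",
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight",
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.alpha",
        "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight",
    ]

    # Equivalent to the check above:
    print(all(k.startswith(("lora_te_", "lora_unet_")) for k in kohya_keys))  # True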
@@ -909,7 +922,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             unet_lora_state_dict = {
                 k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
             }
-            self.unet.load_attn_procs(unet_lora_state_dict)
+            self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
 
             # Load the layers corresponding to text encoder and make necessary adjustments.
             text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
@@ -918,7 +931,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
             if len(text_encoder_lora_state_dict) > 0:
-                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(
+                    text_encoder_lora_state_dict, network_alpha=network_alpha
+                )
                 self._modify_text_encoder(attn_procs_text_encoder)
 
                 # save lora attn procs of text encoder so that it can be easily retrieved
@@ -954,14 +969,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 module = self.text_encoder.get_submodule(name)
                 # Construct a new function that performs the LoRA merging. We will monkey patch
                 # this forward pass.
-                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                attn_processor_name = ".".join(name.split(".")[:-1])
+                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
                 old_forward = module.forward
 
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
+                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                def make_new_forward(old_forward, lora_layer):
+                    def new_forward(x):
+                        return old_forward(x) + lora_layer(x)
+
+                    return new_forward
 
                 # Monkey-patch.
-                module.forward = new_forward
+                module.forward = make_new_forward(old_forward, lora_layer)
 
     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
@@ -1048,6 +1069,7 @@ def _load_text_encoder_attn_procs(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        network_alpha = kwargs.pop("network_alpha", None)
 
         if use_safetensors and not is_safetensors_available():
            raise ValueError(
@@ -1125,7 +1147,10 @@ def _load_text_encoder_attn_procs(
                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
 
             attn_processors[key] = LoRAAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                hidden_size=hidden_size,
+                cross_attention_dim=cross_attention_dim,
+                rank=rank,
+                network_alpha=network_alpha,
             )
             attn_processors[key].load_state_dict(value_dict)
 
@@ -1219,6 +1244,56 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
+    def _convert_kohya_lora_to_diffusers(self, state_dict):
+        unet_state_dict = {}
+        te_state_dict = {}
+        network_alpha = None
+
+        for key, value in state_dict.items():
+            if "lora_down" in key:
+                lora_name = key.split(".")[0]
+                lora_name_up = lora_name + ".lora_up.weight"
+                lora_name_alpha = lora_name + ".alpha"
+                if lora_name_alpha in state_dict:
+                    alpha = state_dict[lora_name_alpha].item()
+                    if network_alpha is None:
+                        network_alpha = alpha
+                    elif network_alpha != alpha:
+                        raise ValueError("Network alpha is not consistent")
+
+                if lora_name.startswith("lora_unet_"):
+                    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+                    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+                    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                    if "transformer_blocks" in diffusers_name:
+                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
+                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+                            unet_state_dict[diffusers_name] = value
+                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                elif lora_name.startswith("lora_te_"):
+                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("text.model", "text_model")
+                    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                    if "self_attn" in diffusers_name:
+                        te_state_dict[diffusers_name] = value
+                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
+        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
+        new_state_dict = {**unet_state_dict, **te_state_dict}
+        return new_state_dict, network_alpha
+
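Note: a worked example of the key renaming performed by `_convert_kohya_lora_to_diffusers`, assuming typical kohya-ss key names (`UNET_NAME` / `TEXT_ENCODER_NAME` supply the `unet.` / `text_encoder.` prefixes expected by `load_lora_weights`):

    # UNet attention LoRA:
    # "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
    #   -> "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"

    # Text-encoder self-attention LoRA:
    # "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"
    #   -> "text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight"

    # The paired ".lora_up.weight" tensors map to the same names with ".up." in place of
    # ".down.", and the per-module ".alpha" scalars collapse into the single network_alpha value.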
 
 class FromCkptMixin:
     """This helper class allows to directly load .ckpt stable diffusion file_extension