@@ -1139,7 +1139,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
                         f"{name}.out_proj.lora_linear_layer.down.weight"
                     ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
 
-            if text_encoder_lora_state_dict:
+            if state_dict_aux:
                 for name, _ in text_encoder_aux_modules(text_encoder):
                     for direction in ["up", "down"]:
                         for layer in ["fc1", "fc2"]:
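The rename in the context lines above follows the usual pop-and-reassign pattern for remapping legacy state-dict keys in place. A minimal, self-contained sketch of that pattern (the helper name and exact key layout are illustrative, not the library's real checkpoint schema):

# Sketch: remap legacy LoRA keys in place with dict.pop().
# The suffix mapping mirrors the rename visible in the hunk above; the helper
# itself is hypothetical and not part of the diffusers API.
def remap_legacy_keys(state_dict):
    legacy_to_new = {
        "to_out_lora.down.weight": "out_proj.lora_linear_layer.down.weight",
        "to_out_lora.up.weight": "out_proj.lora_linear_layer.up.weight",
    }
    for key in list(state_dict):  # snapshot the keys: the dict is mutated below
        for old_suffix, new_suffix in legacy_to_new.items():
            if key.endswith(old_suffix):
                new_key = key[: -len(old_suffix)] + new_suffix
                state_dict[new_key] = state_dict.pop(key)
    return state_dict

sd = {"layers.0.self_attn.to_out_lora.down.weight": "tensor"}
assert "layers.0.self_attn.out_proj.lora_linear_layer.down.weight" in remap_legacy_keys(sd)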
@@ -1186,13 +1186,24 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                 attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer
 
-        for _, aux_module in text_encoder_aux_modules(text_encoder):
-            if isinstance(aux_module.fc1, PatchedLoraProjection):
-                aux_module.fc1 = aux_module.fc1.regular_linear_layer
-                aux_module.fc2 = aux_module.fc2.regular_linear_layer
+        if getattr(text_encoder, "aux_state_dict_populated", False):
+            for _, aux_module in text_encoder_aux_modules(text_encoder):
+                if isinstance(aux_module.fc1, PatchedLoraProjection):
+                    aux_module.fc1 = aux_module.fc1.regular_linear_layer
+                    aux_module.fc2 = aux_module.fc2.regular_linear_layer
+
+            text_encoder.aux_state_dict_populated = False
 
     @classmethod
-    def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+    def _modify_text_encoder(
+        cls,
+        text_encoder,
+        lora_scale=1,
+        network_alpha=None,
+        rank=4,
+        dtype=None,
+        patch_aux=False,
+    ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
         """
@@ -1223,12 +1234,19 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
-        for _, aux_module in text_encoder_aux_modules(text_encoder):
-            aux_module.fc1 = PatchedLoraProjection(aux_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype)
-            lora_parameters.extend(aux_module.fc1.lora_linear_layer.parameters())
+        if patch_aux:
+            for _, aux_module in text_encoder_aux_modules(text_encoder):
+                aux_module.fc1 = PatchedLoraProjection(
+                    aux_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+                )
+                lora_parameters.extend(aux_module.fc1.lora_linear_layer.parameters())
+
+                aux_module.fc2 = PatchedLoraProjection(
+                    aux_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+                )
+                lora_parameters.extend(aux_module.fc2.lora_linear_layer.parameters())
 
-            aux_module.fc2 = PatchedLoraProjection(aux_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype)
-            lora_parameters.extend(aux_module.fc2.lora_linear_layer.parameters())
+            text_encoder.aux_state_dict_populated = True
 
         return lora_parameters
 
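Taken together, the patch_aux argument and the aux_state_dict_populated attribute form a simple handshake: patching the auxiliary (MLP) modules records a flag on the text encoder, and the unpatch path only touches fc1/fc2 when that flag is set, then clears it. A minimal sketch of the same handshake on a toy object (only aux_state_dict_populated mirrors the diff; the other names are hypothetical):

# Sketch of the patch/unpatch handshake used above, on a toy object.
class ToyEncoder:
    def __init__(self):
        self.fc1 = "regular_fc1"
        self.fc2 = "regular_fc2"

def patch_aux_modules(encoder):
    encoder.fc1 = f"patched({encoder.fc1})"
    encoder.fc2 = f"patched({encoder.fc2})"
    encoder.aux_state_dict_populated = True  # remember that aux layers were touched

def unpatch_aux_modules(encoder):
    # Guard with getattr so encoders that were never aux-patched are left alone.
    if getattr(encoder, "aux_state_dict_populated", False):
        encoder.fc1 = "regular_fc1"
        encoder.fc2 = "regular_fc2"
        encoder.aux_state_dict_populated = False

enc = ToyEncoder()
patch_aux_modules(enc)
unpatch_aux_modules(enc)
assert not enc.aux_state_dict_populated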