Commit f778d91

Only patch text encoder related stuff when aux state is in
1 parent c2239a4
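In effect, this change makes auxiliary text-encoder LoRA handling opt-in end to end: the state-dict remapping for aux keys is gated on `state_dict_aux` being present, `_modify_text_encoder` gains a `patch_aux` flag and records the patching in `text_encoder.aux_state_dict_populated`, and the unpatch path only unwraps the aux (`fc1`/`fc2`) modules when that flag is set, clearing it afterwards.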

2 files changed, +29 -12 lines


src/diffusers/loaders.py

Lines changed: 29 additions & 11 deletions
@@ -1139,7 +1139,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
                         f"{name}.out_proj.lora_linear_layer.down.weight"
                     ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
 
-            if text_encoder_lora_state_dict:
+            if state_dict_aux:
                 for name, _ in text_encoder_aux_modules(text_encoder):
                     for direction in ["up", "down"]:
                         for layer in ["fc1", "fc2"]:
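Note that the hunk above is cut off at the innermost loop. Below is a hedged sketch of what that remapping loop plausibly does; the loop structure comes from the diff, but the key formats are assumptions, not from this commit:

# Sketch only: remap aux (MLP) LoRA keys into the layout the patched
# modules expect. The "lora" -> "lora_linear_layer" key formats are assumed.
for name, _ in text_encoder_aux_modules(text_encoder):
    for direction in ["up", "down"]:
        for layer in ["fc1", "fc2"]:
            src = f"{name}.{layer}.lora.{direction}.weight"  # assumed source key
            dst = f"{name}.{layer}.lora_linear_layer.{direction}.weight"  # assumed target key
            if src in state_dict_aux:
                state_dict_aux[dst] = state_dict_aux.pop(src)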
@@ -1186,13 +1186,24 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                 attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer
 
-        for _, aux_module in text_encoder_aux_modules(text_encoder):
-            if isinstance(aux_module.fc1, PatchedLoraProjection):
-                aux_module.fc1 = aux_module.fc1.regular_linear_layer
-                aux_module.fc2 = aux_module.fc2.regular_linear_layer
+        if getattr(text_encoder, "aux_state_dict_populated", False):
+            for _, aux_module in text_encoder_aux_modules(text_encoder):
+                if isinstance(aux_module.fc1, PatchedLoraProjection):
+                    aux_module.fc1 = aux_module.fc1.regular_linear_layer
+                    aux_module.fc2 = aux_module.fc2.regular_linear_layer
+
+            text_encoder.aux_state_dict_populated = False
 
     @classmethod
-    def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+    def _modify_text_encoder(
+        cls,
+        text_encoder,
+        lora_scale=1,
+        network_alpha=None,
+        rank=4,
+        dtype=None,
+        patch_aux=False,
+    ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
         """
@@ -1223,12 +1234,19 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
-        for _, aux_module in text_encoder_aux_modules(text_encoder):
-            aux_module.fc1 = PatchedLoraProjection(aux_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype)
-            lora_parameters.extend(aux_module.fc1.lora_linear_layer.parameters())
+        if patch_aux:
+            for _, aux_module in text_encoder_aux_modules(text_encoder):
+                aux_module.fc1 = PatchedLoraProjection(
+                    aux_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+                )
+                lora_parameters.extend(aux_module.fc1.lora_linear_layer.parameters())
+
+                aux_module.fc2 = PatchedLoraProjection(
+                    aux_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+                )
+                lora_parameters.extend(aux_module.fc2.lora_linear_layer.parameters())
 
-            aux_module.fc2 = PatchedLoraProjection(aux_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype)
-            lora_parameters.extend(aux_module.fc2.lora_linear_layer.parameters())
+            text_encoder.aux_state_dict_populated = True
 
         return lora_parameters
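A hedged usage sketch of the round trip this commit sets up, assuming these classmethods live on `LoraLoaderMixin` (the class name is not visible in this diff):

# Patch the attention modules and, with patch_aux=True, the aux (fc1/fc2)
# modules as well; the call returns the new LoRA parameters for training.
lora_parameters = LoraLoaderMixin._modify_text_encoder(
    text_encoder,
    lora_scale=1.0,
    rank=4,
    patch_aux=True,
)
assert text_encoder.aux_state_dict_populated is True

# Unpatching now touches the aux modules only because the flag is set,
# and it clears the flag on the way out.
LoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder)
assert text_encoder.aux_state_dict_populated is False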

src/diffusers/models/lora.py

Lines changed: 0 additions & 1 deletion
@@ -48,7 +48,6 @@ def forward(self, hidden_states):
         return up_hidden_states.to(orig_dtype)
 
 
-# copied from LoRAConv2dLayer
 class LoRAConv2dLayer(nn.Module):
     def __init__(self, in_features, out_features, rank=4, network_alpha=None):
         super().__init__()
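The one-line change here drops a stale comment: `# copied from LoRAConv2dLayer` sat directly above `LoRAConv2dLayer` itself, so it was self-referential copy-paste residue.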
