
Commit b45204e

Add function to remove monkey-patch for text encoder LoRA (#3649)
* merge undoable-monkeypatch
* remove TEXT_ENCODER_TARGET_MODULES, refactoring
* move create_lora_weight_file
1 parent a8b0f42 commit b45204e

4 files changed: +97 −42 lines

src/diffusers/loaders.py

Lines changed: 49 additions & 32 deletions
@@ -34,7 +34,7 @@
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_TARGET_MODULES,
+    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -955,6 +955,19 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return

+    def _remove_text_encoder_monkey_patch(self):
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    if hasattr(module, "old_forward"):
+                        # restore original `forward` to remove monkey-patch
+                        module.forward = module.old_forward
+                        delattr(module, "old_forward")
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -963,37 +976,41 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
             attn_processors: Dict[str, `LoRAAttnProcessor`]:
                 A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         """
-        # Loop over the original attention modules.
-        for name, _ in self.text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-                # Retrieve the module and its corresponding LoRA processor.
-                module = self.text_encoder.get_submodule(name)
-                # Construct a new function that performs the LoRA merging. We will monkey patch
-                # this forward pass.
-                attn_processor_name = ".".join(name.split(".")[:-1])
-                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
-
-                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                def make_new_forward(old_forward, lora_layer):
-                    def new_forward(x):
-                        return old_forward(x) + lora_layer(x)
-
-                    return new_forward
-
-                # Monkey-patch.
-                module.forward = make_new_forward(old_forward, lora_layer)
-
-    def _get_lora_layer_attribute(self, name: str) -> str:
-        if "q_proj" in name:
-            return "to_q_lora"
-        elif "v_proj" in name:
-            return "to_v_lora"
-        elif "k_proj" in name:
-            return "to_k_lora"
-        else:
-            return "to_out_lora"
+
+        # First, remove any monkey-patch that might have been applied before
+        self._remove_text_encoder_monkey_patch()
+
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
+
+                    # save old_forward to module that can be used to remove monkey-patch
+                    old_forward = module.old_forward = module.forward
+
+                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                    def make_new_forward(old_forward, lora_layer):
+                        def new_forward(x):
+                            return old_forward(x) + lora_layer(x)
+
+                        return new_forward
+
+                    # Monkey-patch.
+                    module.forward = make_new_forward(old_forward, lora_layer)
+
+    @property
+    def _lora_attn_processor_attr_to_text_encoder_attr(self):
+        return {
+            "to_q_lora": "q_proj",
+            "to_k_lora": "k_proj",
+            "to_v_lora": "v_proj",
+            "to_out_lora": "out_proj",
+        }

     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
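
Note: the refactored patching above keeps the closure-based `new_forward`, but now also stashes the original forward on the module itself (`module.old_forward`), which is what makes the patch reversible via the new `_remove_text_encoder_monkey_patch`. A minimal, self-contained sketch of that pattern, with a toy `nn.Linear` standing in for a CLIP q/k/v/out projection (all names below are illustrative, not part of this commit):

import torch
import torch.nn as nn

# Toy stand-ins for a CLIP projection layer and a LoRA layer.
proj = nn.Linear(8, 8)
lora = nn.Linear(8, 8, bias=False)

x = torch.randn(2, 8)
original_out = proj(x)

# Apply the monkey-patch, stashing the original forward on the module itself.
old_forward = proj.old_forward = proj.forward

def make_new_forward(old_forward, lora_layer):
    # The closure locks in old_forward/lora_layer for this specific module.
    def new_forward(inputs):
        return old_forward(inputs) + lora_layer(inputs)

    return new_forward

proj.forward = make_new_forward(old_forward, lora)
patched_out = proj(x)  # original output + LoRA contribution (random weights, so it differs)

# Undo the monkey-patch, mirroring _remove_text_encoder_monkey_patch.
if hasattr(proj, "old_forward"):
    proj.forward = proj.old_forward
    delattr(proj, "old_forward")

assert torch.allclose(proj(x), original_out)  # original behavior is restored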

src/diffusers/utils/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     TEXT_ENCODER_ATTN_MODULE,
-    TEXT_ENCODER_TARGET_MODULES,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate

src/diffusers/utils/constants.py

Lines changed: 0 additions & 1 deletion
@@ -30,5 +30,4 @@
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
 TEXT_ENCODER_ATTN_MODULE = ".self_attn"
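
Note: with TEXT_ENCODER_TARGET_MODULES removed, module discovery relies only on the TEXT_ENCODER_ATTN_MODULE suffix kept above: each CLIPAttention block is matched once, and its q/k/v/out projections are then resolved through the `_lora_attn_processor_attr_to_text_encoder_attr` map, rather than substring-matching every projection module. A small sketch of the difference, assuming the module naming of transformers' CLIPTextModel (the example path is illustrative):

# Illustrative CLIP text-encoder module name (the layer index is arbitrary).
name = "text_model.encoder.layers.0.self_attn"

TEXT_ENCODER_ATTN_MODULE = ".self_attn"

# New approach: match the whole CLIPAttention block once per layer...
assert name.endswith(TEXT_ENCODER_ATTN_MODULE)

# ...then resolve its projections through the attribute map from loaders.py.
attr_map = {
    "to_q_lora": "q_proj",
    "to_k_lora": "k_proj",
    "to_v_lora": "v_proj",
    "to_out_lora": "out_proj",
}
projection_names = [f"{name}.{proj}" for proj in attr_map.values()]

# Old approach: substring-match each projection module individually.
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
assert all(any(t in p for t in TEXT_ENCODER_TARGET_MODULES) for p in projection_names)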

tests/models/test_lora_layers.py

Lines changed: 48 additions & 8 deletions
@@ -163,6 +163,15 @@ def get_dummy_inputs(self):

         return noise, input_ids, pipeline_inputs

+    def create_lora_weight_file(self, tmpdirname):
+        _, lora_components = self.get_dummy_components()
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=tmpdirname,
+            unet_lora_layers=lora_components["unet_lora_layers"],
+            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+        )
+        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+
     def test_lora_save_load(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -299,14 +308,45 @@ def test_text_encoder_lora_monkey_patch(self):
             outputs_without_lora, outputs_with_lora
         ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"

-    def create_lora_weight_file(self, tmpdirname):
-        _, lora_components = self.get_dummy_components()
-        LoraLoaderMixin.save_lora_weights(
-            save_directory=tmpdirname,
-            unet_lora_layers=lora_components["unet_lora_layers"],
-            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
-        )
-        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+    def test_text_encoder_lora_remove_monkey_patch(self):
+        pipeline_components, _ = self.get_dummy_components()
+        pipe = StableDiffusionPipeline(**pipeline_components)
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 32)
+
+        # create lora_attn_procs with randn up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora outputs should be different to without lora outputs"
+
+        # remove monkey patch
+        pipe._remove_text_encoder_monkey_patch()
+
+        # inference with removed lora
+        outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora_removed.shape == (1, 77, 32)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_without_lora_removed
+        ), "remove lora monkey patch should restore the original outputs"

     def test_lora_unet_attn_processors(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
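
Note: because `_modify_text_encoder` now calls `_remove_text_encoder_monkey_patch()` up front, re-applying LoRA replaces the previous patch instead of stacking another `lora_layer(x)` term on top of it, and `old_forward` always refers to the true original forward. A toy sketch of that behavior (the `patch` helper below is hypothetical; it just condenses the two methods into one function):

import torch
import torch.nn as nn

proj = nn.Linear(4, 4)
lora_a = nn.Linear(4, 4, bias=False)
lora_b = nn.Linear(4, 4, bias=False)

def patch(module, lora_layer):
    # Mirrors _modify_text_encoder: undo any existing patch first so that
    # old_forward always points at the true original forward.
    if hasattr(module, "old_forward"):
        module.forward = module.old_forward
        delattr(module, "old_forward")
    old_forward = module.old_forward = module.forward

    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    module.forward = new_forward

x = torch.randn(1, 4)
patch(proj, lora_a)
patch(proj, lora_b)  # replaces lora_a rather than stacking on top of it

expected = proj.old_forward(x) + lora_b(x)
assert torch.allclose(proj(x), expected)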
