Add function to remove monkey-patch for text encoder LoRA #3649

Merged (5 commits) on Jun 6, 2023
81 changes: 49 additions & 32 deletions src/diffusers/loaders.py
@@ -34,7 +34,7 @@
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
TEXT_ENCODER_TARGET_MODULES,
TEXT_ENCODER_ATTN_MODULE,
_get_model_file,
deprecate,
is_safetensors_available,
@@ -955,6 +955,19 @@ def text_encoder_lora_attn_procs(self):
return self._text_encoder_lora_attn_procs
return

def _remove_text_encoder_monkey_patch(self):
# Loop over the CLIPAttention module of text_encoder
for name, attn_module in self.text_encoder.named_modules():
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
# Loop over the LoRA layers
for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
# Retrieve the q/k/v/out projection of CLIPAttention
module = attn_module.get_submodule(text_encoder_attr)
if hasattr(module, "old_forward"):
# restore original `forward` to remove monkey-patch
module.forward = module.old_forward
delattr(module, "old_forward")

def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
@@ -963,37 +976,41 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
attn_processors: Dict[str, `LoRAAttnProcessor`]:
A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
"""
# Loop over the original attention modules.
for name, _ in self.text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
# Retrieve the module and its corresponding LoRA processor.
module = self.text_encoder.get_submodule(name)
# Construct a new function that performs the LoRA merging. We will monkey patch
# this forward pass.
attn_processor_name = ".".join(name.split(".")[:-1])
lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
old_forward = module.forward

# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
def make_new_forward(old_forward, lora_layer):
def new_forward(x):
return old_forward(x) + lora_layer(x)

return new_forward

# Monkey-patch.
module.forward = make_new_forward(old_forward, lora_layer)

def _get_lora_layer_attribute(self, name: str) -> str:
if "q_proj" in name:
return "to_q_lora"
elif "v_proj" in name:
return "to_v_lora"
elif "k_proj" in name:
return "to_k_lora"
else:
return "to_out_lora"

# First, remove any monkey-patch that might have been applied before
self._remove_text_encoder_monkey_patch()

# Loop over the CLIPAttention module of text_encoder
for name, attn_module in self.text_encoder.named_modules():
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
# Loop over the LoRA layers
for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
# Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
module = attn_module.get_submodule(text_encoder_attr)
lora_layer = attn_processors[name].get_submodule(attn_proc_attr)

# save old_forward to module that can be used to remove monkey-patch
old_forward = module.old_forward = module.forward

# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
def make_new_forward(old_forward, lora_layer):
def new_forward(x):
return old_forward(x) + lora_layer(x)

return new_forward

# Monkey-patch.
module.forward = make_new_forward(old_forward, lora_layer)

@property
def _lora_attn_processor_attr_to_text_encoder_attr(self):
return {
"to_q_lora": "q_proj",
"to_k_lora": "k_proj",
"to_v_lora": "v_proj",
"to_out_lora": "out_proj",
}
Comment on lines +1006 to +1013
Member

Nice clean-up!

def _load_text_encoder_attn_procs(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
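Taken together, the patch-and-restore logic above is a small pattern: wrap each projection's `forward` in a closure produced by a factory, stash the original `forward` on the module as `old_forward`, and put it back to undo the patch. Below is a minimal, self-contained sketch of that pattern with toy `nn.Linear` modules standing in for the CLIP projections and LoRA layers (the names `proj` and `lora` are illustrative only, not diffusers APIs); it also shows why the `make_new_forward` factory is needed instead of defining `new_forward` directly inside the loop.

```python
import torch
import torch.nn as nn

# Toy stand-ins: `proj` plays the role of CLIPAttention's q/k/v/out projections,
# `lora` plays the role of the LoRA layers added on top of them.
proj = nn.ModuleDict({"q_proj": nn.Linear(4, 4), "k_proj": nn.Linear(4, 4)})
lora = {name: nn.Linear(4, 4, bias=False) for name in proj}

x = torch.randn(1, 4)
reference = {name: module(x) for name, module in proj.items()}

def make_new_forward(old_forward, lora_layer):
    # The factory gives each module its own scope. Defining `new_forward`
    # directly in the loop below would late-bind the loop variables, so every
    # patched module would end up calling the last iteration's layers.
    def new_forward(hidden_states):
        return old_forward(hidden_states) + lora_layer(hidden_states)

    return new_forward

# Monkey-patch: keep a handle to the original forward so the patch can be undone.
for name, module in proj.items():
    module.old_forward = module.forward
    module.forward = make_new_forward(module.old_forward, lora[name])

assert not torch.allclose(proj["q_proj"](x), reference["q_proj"])

# Remove the monkey-patch, mirroring `_remove_text_encoder_monkey_patch`:
# restore the original forward and drop the stored handle.
for module in proj.values():
    if hasattr(module, "old_forward"):
        module.forward = module.old_forward
        delattr(module, "old_forward")

assert torch.allclose(proj["q_proj"](x), reference["q_proj"])
```

Keeping the original `forward` on the module itself means no separate bookkeeping has to survive between `_modify_text_encoder` and `_remove_text_encoder_monkey_patch`, and the `hasattr(module, "old_forward")` guard makes the removal safe to call even when nothing was patched.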
1 change: 0 additions & 1 deletion src/diffusers/utils/__init__.py
@@ -31,7 +31,6 @@
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
TEXT_ENCODER_ATTN_MODULE,
TEXT_ENCODER_TARGET_MODULES,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
1 change: 0 additions & 1 deletion src/diffusers/utils/constants.py
@@ -30,5 +30,4 @@
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
TEXT_ENCODER_ATTN_MODULE = ".self_attn"
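Replacing the `TEXT_ENCODER_TARGET_MODULES` list with the single `TEXT_ENCODER_ATTN_MODULE` suffix works because every matched `CLIPAttention` module already exposes `q_proj`, `k_proj`, `v_proj`, and `out_proj` as submodules, so the loader only needs to find the attention blocks themselves. A short sketch (assuming `transformers` is installed; the tiny config values are arbitrary) of what the suffix matches:

```python
from transformers import CLIPTextConfig, CLIPTextModel

config = CLIPTextConfig(hidden_size=32, intermediate_size=37, num_attention_heads=4, num_hidden_layers=2)
text_encoder = CLIPTextModel(config)

# The ".self_attn" suffix picks out one CLIPAttention module per transformer layer.
attn_names = [name for name, _ in text_encoder.named_modules() if name.endswith(".self_attn")]
print(attn_names)
# e.g. ['text_model.encoder.layers.0.self_attn', 'text_model.encoder.layers.1.self_attn']

# Each matched module carries the four projections that _modify_text_encoder patches.
print([name for name, _ in text_encoder.get_submodule(attn_names[0]).named_children()])
# e.g. ['k_proj', 'v_proj', 'q_proj', 'out_proj']
```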
56 changes: 48 additions & 8 deletions tests/models/test_lora_layers.py
@@ -163,6 +163,15 @@ def get_dummy_inputs(self):

return noise, input_ids, pipeline_inputs

def create_lora_weight_file(self, tmpdirname):
_, lora_components = self.get_dummy_components()
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))

def test_lora_save_load(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -299,14 +308,45 @@ def test_text_encoder_lora_monkey_patch(self):
outputs_without_lora, outputs_with_lora
), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"

def create_lora_weight_file(self, tmpdirname):
_, lora_components = self.get_dummy_components()
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
def test_text_encoder_lora_remove_monkey_patch(self):
pipeline_components, _ = self.get_dummy_components()
pipe = StableDiffusionPipeline(**pipeline_components)

dummy_tokens = self.get_dummy_tokens()

# inference without lora
outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora.shape == (1, 77, 32)

# create lora_attn_procs with randn up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=True)

# monkey patch
pipe._modify_text_encoder(text_attn_procs)

# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
del text_attn_procs
gc.collect()

# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_with_lora.shape == (1, 77, 32)

assert not torch.allclose(
outputs_without_lora, outputs_with_lora
), "lora outputs should be different to without lora outputs"

# remove monkey patch
pipe._remove_text_encoder_monkey_patch()

# inference with removed lora
outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora_removed.shape == (1, 77, 32)

assert torch.allclose(
outputs_without_lora, outputs_without_lora_removed
), "remove lora monkey patch should restore the original outputs"

def test_lora_unet_attn_processors(self):
with tempfile.TemporaryDirectory() as tmpdirname:
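Beyond the round-trip check in the test above, the practical effect is that `_modify_text_encoder` now starts from a clean text encoder every time, so loading a different set of LoRA weights into the same pipeline replaces the previous patch instead of stacking another wrapped `forward` on top of it. A hypothetical usage sketch (the LoRA repository IDs are placeholders, and `_remove_text_encoder_monkey_patch` is a private helper, so this is illustrative rather than documented API):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Placeholder repository containing LoRA weights that include text encoder layers.
pipe.load_lora_weights("some-user/lora-style-a")
# ... run inference ...

# Loading a second set replaces the previous text encoder patch rather than stacking on it.
pipe.load_lora_weights("some-user/lora-style-b")
# ... run inference ...

# If needed, restore the text encoder's original forward passes entirely.
pipe._remove_text_encoder_monkey_patch()
```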