Add function to remove monkey-patch for text encoder LoRA #3649

Merged: 5 commits merged into huggingface:main on Jun 6, 2023

Conversation

takuma104 (Contributor)

Discussed in #3621.

As pointed out by @rvorias in #3490 (comment), if we call pipe.load_lora_weights() twice (with a checkpoint containing text encoder LoRA parameters), monkey-patching becomes recursive.
This should be prevented at all costs.

To solve the above issue, I created pipe._remove_text_encoder_monkey_patch(). When the user calls pipe.load_lora_weights(), it automatically removes any already applied monkey-patch.
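
For context, here is a minimal, self-contained sketch (illustrative only, not the diffusers implementation; all names below are made up) of why re-applying a forward monkey-patch nests the LoRA term, and how undoing the previous patch first avoids that:

import torch

proj = torch.nn.Linear(4, 4)              # stands in for e.g. q_proj inside CLIPAttention
lora = torch.nn.Linear(4, 4, bias=False)  # stands in for a LoRA layer

def apply_patch(module, lora_layer):
    # If the module is already patched, this captures the *patched* forward,
    # so patching again would stack the LoRA term recursively.
    old_forward = module.forward

    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    module.old_forward = old_forward      # keep a handle so the patch can be undone
    module.forward = new_forward

def remove_patch(module):
    # Restore the original forward if a patch was applied earlier.
    if hasattr(module, "old_forward"):
        module.forward = module.old_forward
        del module.old_forward

apply_patch(proj, lora)
remove_patch(proj)                        # undo the previous patch first...
apply_patch(proj, lora)                   # ...so reloading never nests the LoRA term
out = proj(torch.randn(1, 4))             # a single LoRA contribution, as expected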

HuggingFaceDocBuilderDev commented on Jun 2, 2023:

The documentation is not available anymore as the PR was closed or merged.

patrickvonplaten (Contributor) left a comment:

Nice, I very much like the approach here!

patrickvonplaten (Contributor) left a comment:

This looks good to me! @sayakpaul could you have a look? :-)

@@ -955,6 +955,15 @@ def text_encoder_lora_attn_procs(self):
return self._text_encoder_lora_attn_procs
return

def _remove_text_encoder_monkey_patch(self):
for name, _ in self.text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
Review comment from a Member:

Shouldn't this be TEXT_ENCODER_ATTN_MODULE? I see this isn't reflected here either:

if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):

I am asking this because when initializing the LoRA layers for the text encoder, we use TEXT_ENCODER_ATTN_MODULE:

if name.endswith(TEXT_ENCODER_ATTN_MODULE):

Doesn't this create a disparity?

takuma104 (Contributor, Author) replied:

This code is correct because the target of the monkey-patch is not the CLIPAttention module itself, but its q/k/v/out_proj attributes. However, as you pointed out, aligning this with train_dreambooth_lora.py makes it easier to understand, so I tried refactoring to remove TEXT_ENCODER_TARGET_MODULES in 356a46a. WDYT?
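
For illustration, here is a toy sketch of the two matching styles being discussed (the constants' values are assumed here, not copied from the diffusers source); both reach the same attention blocks, just at different depths:

import torch

# Toy stand-in for one CLIP text-encoder layer; real module names look like
# "text_model.encoder.layers.0.self_attn.q_proj".
class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(4, 4)
        self.k_proj = torch.nn.Linear(4, 4)
        self.v_proj = torch.nn.Linear(4, 4)
        self.out_proj = torch.nn.Linear(4, 4)

class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = ToyAttention()

text_encoder = torch.nn.ModuleDict({"layers": torch.nn.ModuleList([ToyLayer()])})

TEXT_ENCODER_ATTN_MODULE = ".self_attn"                                   # assumed value
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj"]  # assumed value

attn_blocks = [n for n, _ in text_encoder.named_modules() if n.endswith(TEXT_ENCODER_ATTN_MODULE)]
projections = [n for n, _ in text_encoder.named_modules() if any(x in n for x in TEXT_ENCODER_TARGET_MODULES)]
print(attn_blocks)   # ['layers.0.self_attn']             -> the attention block itself
print(projections)   # ['layers.0.self_attn.q_proj', ...] -> the projections inside it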

assert torch.allclose(
outputs_without_lora, outputs_without_lora_removed
), "remove lora monkey patch should restore the original outputs"

def create_lora_weight_file(self, tmpdirname):
Review comment from a Member:

While we're at it, could we rename this test to test_create_lora_weight_file()?

takuma104 (Contributor, Author) replied:

This function is a utility, not test code. To avoid confusion, I moved it to sit with the other utility functions at the top of the file. 5d7939d

sayakpaul (Member) left a comment:

Looks amazing to me.

Left some comments for clarification; once those are sorted, we can ship this 🚀

takuma104 (Contributor, Author) commented on Jun 5, 2023:

@patrickvonplaten @sayakpaul Thanks for the review! I left comments:
#3649 (comment)
#3649 (comment)

rvorias commented on Jun 5, 2023:

It seems like removing them is the only option, since apparently you can't just load a new state dict into the LoRA layers; the monkey-patched term still remains in effect. Thanks for implementing this!

Comment on lines +1006 to +1013
@property
def _lora_attn_processor_attr_to_text_encoder_attr(self):
return {
"to_q_lora": "q_proj",
"to_k_lora": "k_proj",
"to_v_lora": "v_proj",
"to_out_lora": "out_proj",
}
Review comment from a Member:

Nice clean-up!
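
As a usage note, here is a small sketch (the toy attention module and the loop below are hypothetical, not the diffusers code) of how such a mapping lets a single loop resolve each LoRA-processor attribute to the matching text-encoder projection:

import torch

class ToyAttention(torch.nn.Module):   # hypothetical stand-in for CLIPAttention
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(4, 4)
        self.k_proj = torch.nn.Linear(4, 4)
        self.v_proj = torch.nn.Linear(4, 4)
        self.out_proj = torch.nn.Linear(4, 4)

lora_attr_to_text_encoder_attr = {
    "to_q_lora": "q_proj",
    "to_k_lora": "k_proj",
    "to_v_lora": "v_proj",
    "to_out_lora": "out_proj",
}

attn_module = ToyAttention()
for lora_attr, proj_attr in lora_attr_to_text_encoder_attr.items():
    proj = getattr(attn_module, proj_attr)   # the module whose forward gets (un)patched
    print(f"{lora_attr} -> {proj_attr}: {type(proj).__name__}")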

sayakpaul self-requested a review on June 6, 2023 at 08:35.
sayakpaul (Member) left a comment:

Looking awesome!

sayakpaul merged commit b45204e into huggingface:main on Jun 6, 2023.
JemiloII commented:

@sayakpaul Can we revert this? I can no longer load multiple LoRAs.

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request on Dec 25, 2023 (…e#3649):

* merge undoable-monkeypatch
* remove TEXT_ENCODER_TARGET_MODULES, refactoring
* move create_lora_weight_file
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024 (…e#3649):

* merge undoable-monkeypatch
* remove TEXT_ENCODER_TARGET_MODULES, refactoring
* move create_lora_weight_file