Add function to remove monkey-patch for text encoder LoRA #3649
Conversation
The documentation is not available anymore as the PR was closed or merged.
Nice, I very much like the approach here!
This looks good to me! @sayakpaul could you have a look? :-)
src/diffusers/loaders.py (outdated)

```diff
@@ -955,6 +955,15 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return

+    def _remove_text_encoder_monkey_patch(self):
+        for name, _ in self.text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
```
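For context, a minimal sketch of how such a removal can be completed, assuming (as this PR does) that the monkey-patch stashes the original `forward` of each patched projection as `old_forward`. The names mirror the diff above and are illustrative, not the final implementation:

```python
# Sketch: undo the forward() monkey-patch on the text encoder's target
# sub-modules, assuming the patch saved the original method as `old_forward`.
def _remove_text_encoder_monkey_patch(self):
    for name, module in self.text_encoder.named_modules():
        if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
            if hasattr(module, "old_forward"):
                # Restore the original forward and drop the stashed reference.
                module.forward = module.old_forward
                delattr(module, "old_forward")
```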
Shouldn't this be `TEXT_ENCODER_ATTN_MODULE`? I see this doesn't get reflected here either:

`diffusers/src/diffusers/loaders.py`, line 968 in 523a50a:

```python
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
```

I am asking because when initializing the LoRA layers for the text encoder, we use `TEXT_ENCODER_ATTN_MODULE`:

```python
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
```

Doesn't this create a disparity?
This code is correct because the target of the monkey-patch is not the `CLIPAttention` module itself, but its `q_proj`/`k_proj`/`v_proj`/`out_proj` attributes. However, as you pointed out, I thought it would be easier to understand if it were aligned with `train_dreambooth_lora.py`, so I tried refactoring to remove `TEXT_ENCODER_TARGET_MODULES` in 356a46a. WDYT?
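To illustrate the point about the patch target, here is a simplified, hypothetical sketch (not the actual diffusers code): the wrapper replaces `forward` on the projection layers, while `CLIPAttention` itself is left untouched.

```python
# Simplified illustration of where the monkey-patch lives: the q/k/v/out_proj
# layers get a wrapped forward(); the enclosing attention module does not.
def patch_projection(proj_module, lora_layer):
    old_forward = proj_module.forward

    def new_forward(x):
        # Original projection output plus the LoRA delta.
        return old_forward(x) + lora_layer(x)

    # Keep a handle to the original so the patch can be undone later.
    proj_module.old_forward = old_forward
    proj_module.forward = new_forward
```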
tests/models/test_lora_layers.py (outdated)

```python
        assert torch.allclose(
            outputs_without_lora, outputs_without_lora_removed
        ), "remove lora monkey patch should restore the original outputs"

    def create_lora_weight_file(self, tmpdirname):
```
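As a rough sketch of the kind of check this assertion performs (hypothetical names; `pipe`, `input_ids`, and the LoRA path are placeholders, not the test's actual fixtures):

```python
import torch

def check_monkey_patch_removal(pipe, input_ids, lora_path):
    # Text encoder output before any LoRA is applied.
    with torch.no_grad():
        outputs_without_lora = pipe.text_encoder(input_ids)[0]

    # Apply a LoRA (which monkey-patches the text encoder), then undo it.
    pipe.load_lora_weights(lora_path)
    pipe._remove_text_encoder_monkey_patch()

    # The output should be back to the original, un-patched behavior.
    with torch.no_grad():
        outputs_without_lora_removed = pipe.text_encoder(input_ids)[0]

    assert torch.allclose(
        outputs_without_lora, outputs_without_lora_removed
    ), "remove lora monkey patch should restore the original outputs"
```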
While we're at it, could we rename this test to `test_create_lora_weight_file()`?
This function is a utility, not test code. To avoid confusion, I moved it to where the other utility functions live, at the top of the file. 5d7939d
Looks amazing to me.
Left some comments for clarification; once those are sorted, we can ship this 🚀
@patrickvonplaten @sayakpaul Thanks for the review! I left comments.
Seems like removing them is the only option, as apparently you can't just load a new state dict into the LoRA layers; the monkey patch would otherwise remain in effect. Thanks for implementing this!
```python
    @property
    def _lora_attn_processor_attr_to_text_encoder_attr(self):
        return {
            "to_q_lora": "q_proj",
            "to_k_lora": "k_proj",
            "to_v_lora": "v_proj",
            "to_out_lora": "out_proj",
        }
```
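To show how such a mapping can be consumed, here is a hedged, illustrative sketch (the helper name is hypothetical): it pairs each LoRA layer of an attention processor with the matching projection of a `CLIPAttention` module.

```python
# Illustrative helper: walk the mapping and yield (LoRA layer, projection)
# pairs, e.g. "to_q_lora" -> q_proj, so callers can patch or un-patch them.
def iter_lora_and_projections(attn_processor, attn_module, mapping):
    for attn_proc_attr, text_encoder_attr in mapping.items():
        lora_layer = getattr(attn_processor, attn_proc_attr)        # e.g. to_q_lora
        proj_module = attn_module.get_submodule(text_encoder_attr)  # e.g. q_proj
        yield lora_layer, proj_module
```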
Nice clean-up!
Looking awesome!
@sayakpaul Can we revert this? I can no longer load multiple LoRAs.
…e#3649)
* merge undoable-monkeypatch
* remove TEXT_ENCODER_TARGET_MODULES, refactoring
* move create_lora_weight_file
Discussed in #3621.

To solve the above issue, I created `pipe._remove_text_encoder_monkey_patch()`. When the user calls `pipe.load_lora_weights()`, it automatically removes any already-applied monkey-patch.
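From a user's perspective, the behavior could look like the following sketch (the model ID and LoRA paths are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Loading a LoRA monkey-patches the text encoder projections.
pipe.load_lora_weights("path/to/first_lora")

# Loading another LoRA first removes the existing patch, so the two
# adapters do not silently stack on top of each other.
pipe.load_lora_weights("path/to/second_lora")
```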