[LoRA] Enabling limited LoRA support for text encoder #2882
Conversation
        The text encoder module underlying a [`~DiffusionPipeline`].
        """

    def __init__(self, text_encoder: nn.Module):
Can we remove the `__init__` from the Mixin? I think we could call `_initialize_lora_layers()` in `load_attn_procs` instead, no?

I'm not a fan of Mixins having inits, because it means they can't be "plugged" into the StableDiffusionPipeline class.
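For illustration, here is a minimal sketch of what "no `__init__` on the mixin" implies - the LoRA layers get created lazily inside the loading method, so the mixin can be mixed into a pipeline class without touching its constructor. The method names mirror the ones discussed in this thread; this is not the PR's actual code.

```python
class TextEncoderLoRAMixin:
    """Sketch only: no __init__, so the class can be plugged into a pipeline."""

    def load_attn_procs(self, state_dict: dict):
        # Create the LoRA layers on first use instead of in __init__ ...
        self._initialize_lora_layers()
        # ... then load `state_dict` into the freshly created layers (omitted here).

    def _initialize_lora_layers(self):
        # Attach small trainable LoRA modules to the text encoder's
        # attention projections (details omitted in this sketch).
        ...
```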
+1, mixins should ideally not have `__init__`.
I don't envision this Mixin being used with a `DiffusionPipeline` class. `_initialize_lora_layers()` initializes the LoRA parameters, and having it inside `load_attn_procs()` is not a good choice IMO since the two methods do conceptually different things.
Disclaimer: I'm thinking here more about inference of text encoder LoRAs than training.

I think we should have a function called `load_lora()` or `load_lora_weights(...)` that can be called from `StableDiffusionPipeline(...)`.

I don't think we should wrap the text encoder:

text_encoder_lora_wrapper = TextEncoderLoRAMixin(text_encoder)

=> this breaks things for inference: `text_encoder_lora_wrapper` cannot be passed to the StableDiffusionPipeline because it doesn't have a forward method, it cannot be saved, etc.
I like the design a lot - it's super cool that we can re-use the `AttnProcsLayers` class here. I think it would however be nicer if we don't have to save a new weight file for the TextEncoder.

I'd maybe suggest to:

- Call `TextEncoderLoRAMixin` just `LoraLoaderMixin`, and inside the `LoraLoaderMixin` assume that the unet has a `UNet2DConditionLoadersMixin`.
- Remove the init - I don't think MixinLoaders should have an `init(...)`.
- Make sure only one file is saved per LoRA. If someone trains both the text encoder and the unet LoRA, we only want one file to be saved IMO.
So IMO we should aim for the following API:
pipe = StableDiffusionPipeline.from_pretrained("...")
pipe.load_lora_weights("path-to-lora")
Now `load_lora_weights` is part of the `LoraLoaderMixin`, and it loads the state_dict from "path-to-lora" (either local, from the Hub, or a PyTorch state dict). It then passes the unet part of the loaded state dict to `UNet2DConditionLoadersMixin.load_attn_procs` (note that this input accepts not just filenames but also PyTorch state_dicts). Then it calls `_initialize_lora_layers` for the text encoder and finally loads the text part of the state_dict into the text encoder (see the sketch after this comment).

This way we can just plug this Mixin into every pipeline class and don't have to worry about any `super().__init__()` problems. The only problem here is that the `LoraLoaderMixin` has to know the names of a) the text encoder and b) the unet. However, we could easily solve this with class attributes, e.g. we just give `LoraLoaderMixin` two class attributes:
class LoraLoaderMixin:
text_encoder_name = None
unet_name = None
And those are then overwritten in the StableDiffusionPipeline class (this is a common API that we already use, e.g. `config_name = CONFIG_NAME`).
I think we can just stick to the weight name "pytorch_lora_weights.bin"
Wdyt @sayakpaul ?
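For concreteness, here is a rough sketch of the `load_lora_weights` flow described above. Assumptions: the pipeline exposes `self.unet` and `self.text_encoder`, the single weight file uses "unet." / "text_encoder." key prefixes, and `_initialize_lora_layers` / `_load_text_encoder_lora` are hypothetical helpers - this is not the final diffusers implementation.

```python
import torch


class LoraLoaderMixin:
    text_encoder_name = None  # overwritten by the pipeline class, e.g. "text_encoder"
    unet_name = None          # overwritten by the pipeline class, e.g. "unet"

    def load_lora_weights(self, pretrained_path_or_dict, **kwargs):
        # Accept either an already-loaded state dict or a path to e.g.
        # "pytorch_lora_weights.bin" (local or downloaded from the Hub).
        if isinstance(pretrained_path_or_dict, dict):
            state_dict = pretrained_path_or_dict
        else:
            state_dict = torch.load(pretrained_path_or_dict, map_location="cpu")

        # Split the single file into the UNet part and the text encoder part.
        unet_sd = {
            k[len(self.unet_name) + 1 :]: v
            for k, v in state_dict.items()
            if k.startswith(f"{self.unet_name}.")
        }
        text_sd = {
            k[len(self.text_encoder_name) + 1 :]: v
            for k, v in state_dict.items()
            if k.startswith(f"{self.text_encoder_name}.")
        }

        # The UNet already has a loader mixin whose load_attn_procs accepts
        # plain state dicts, not just filenames.
        self.unet.load_attn_procs(unet_sd)

        # For the text encoder, first create the LoRA layers, then load into them.
        self._initialize_lora_layers(self.text_encoder)
        self._load_text_encoder_lora(text_sd)
```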
Very cool, and great that we can leverage the existing API!

+1 for Patrick's suggestion. Agree that having one loader class would be better; this way, users won't have to worry about using the mixin.
> +1, mixins should ideally not have `__init__`

How are the LoRA layers for the text encoder initialized during training then? We need to ensure that users follow the exact approach taken in
    vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config).to(torch_device)
text_encoder_lora_wrapper = TextEncoderLoRAMixin(copy.deepcopy(text_encoder))
I agree with @patrickvonplaten and @patil-suraj: if this is the way to use the new class, then it shouldn't be a mixin.

Following up on the discussion, I think we may need a TextEncoderLoRAMixin (or a helper class, if we can't use a mixin) in addition to a pipeline-level mixin that calls both the text encoder and the UNet loaders. Would that work for training, @sayakpaul?
I've mostly thought about inference here - things are indeed a bit trickier for training. I think the important part is the inference part though. How about the following:

from diffusers.loaders import LoRALoaderMixin
...
text_encoder_lora_layers = LoRALoaderMixin.initialize_lora_layers(text_encoder)

=> This would also be a pretty nice API, since one only needs to train the LoRA layers and not the whole encoder, so here we've directly separated the trainable weights from the non-trainable weights.
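A hypothetical sketch of what such an `initialize_lora_layers` helper could do for a CLIP text encoder (written as a standalone function rather than the classmethod proposed above; the LoRA layer definition, the `rank` default, and the monkey-patching step are assumptions, not the merged implementation):

```python
import torch.nn as nn


class LoRALinearLayer(nn.Module):
    """A small low-rank adapter: out = up(down(x))."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


def initialize_lora_layers(text_encoder: nn.Module, rank: int = 4) -> nn.ModuleList:
    # Collect one LoRA adapter per attention projection of the text encoder.
    # Only these adapters are trainable; the encoder itself stays frozen.
    lora_layers = nn.ModuleList()
    for name, module in text_encoder.named_modules():
        if isinstance(module, nn.Linear) and name.endswith(
            ("q_proj", "k_proj", "v_proj", "out_proj")
        ):
            lora_layers.append(
                LoRALinearLayer(module.in_features, module.out_features, rank)
            )
            # The real proposal would also monkey-patch `module.forward` so the
            # LoRA output is added to the original projection's output.
    return lora_layers
```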
There's no inference if there's no training. So, I respectfully disagree. But that said, I really like what you're proposing overall here. However, this introduces a discrepancy between how we initialize the LoRA layers for the UNet and the text encoder (see `diffusers/examples/dreambooth/train_dreambooth_lora.py`, lines 714 to 729 at b202127; a paraphrased version of that snippet follows below).

For the UNet, the initialization part is handled explicitly, whereas for the text encoder we're thinking of having a class method. I don't mind having an explicit initialization for the text encoder LoRA layers as well, to keep the flow consistent and simple (over easy).
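The explicit UNet-side initialization referenced above looks roughly like this (a paraphrase, not a verbatim copy of the script; import paths and exact details may vary across diffusers versions):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention blocks ("attn1") have no cross-attention dimension.
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)
```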
Closing this because the conflicts are brutal. Opened #2918.
Potentially closes #2719.
The community has shown that using LoRA for fine-tuning both the UNet and the text encoder while performing DreamBooth-like training has been quite effective.
Diffusers supports LoRA for UNet but not for the text encoder (see #2719 for details).
This PR introduces limited LoRA support for the text encoder using monkey patching. Here's the overall API design:
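The code block that originally illustrated the API did not survive this capture. Based on the test snippet quoted earlier in the thread and the method names discussed in the review, the proposed usage looked roughly like the following hypothetical reconstruction (`TextEncoderLoRAMixin` comes from this PR's branch; the module path and method name are assumptions):

```python
import copy

from transformers import CLIPTextModel

# Module path assumed; the mixin is introduced by this PR's branch.
from diffusers.loaders import TextEncoderLoRAMixin

text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)

# Wrapping the text encoder monkey-patches its attention projections
# with LoRA layers.
text_encoder_lora_wrapper = TextEncoderLoRAMixin(copy.deepcopy(text_encoder))

# Loading mirrors the UNet's attention-processor API (method name assumed
# from the review discussion above).
text_encoder_lora_wrapper.load_attn_procs("path-to-text-encoder-lora")
```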
Gotchas to be aware of:

- There is currently no way to control the `scale` used for merging the LoRA parameters with the corresponding text encoder parameters. This is because `cross_attention_kwargs` wouldn't work with the text encoder. If you have any ideas to tackle this issue, let me know.

Note that this PR does not modify the `train_dreambooth_lora.py` script yet. I want to do that in a separate PR once this PR is merged (after modifications and discussions).