[LoRA] Enabling limited LoRA support for text encoder #2882


Closed. sayakpaul wants to merge 8 commits from the feat/lora-text-enc branch.

Conversation

@sayakpaul (Member) commented Mar 29, 2023

Potentially closes #2719.

The community has shown that using LoRA for fine-tuning both the UNet and the text encoder while performing DreamBooth-like training has been quite effective.

Diffusers supports LoRA for the UNet but not for the text encoder (see #2719 for details).

This PR introduces limited LoRA support for the text encoder using monkey patching. Here's the overall API design:

from diffusers.loaders import TextEncoderLoRAMixin
from transformers import CLIPTextModel

def get_text_encoder():
    return CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
    )

text_encoder = get_text_encoder()

### Initialization

# Registers the `text_encoder` as a class member, amongst other things.
text_encoder_lora_wrapper = TextEncoderLoRAMixin(text_encoder) 
text_encoder_lora_layers = text_encoder_lora_wrapper.text_encoder_lora_layers

### Perform training of `text_encoder_lora_layers`.

### Save `text_encoder_lora_layers`.
text_encoder_lora_wrapper.save_attn_procs(".", text_encoder_lora_layers)

### Load.
text_encoder = text_encoder_lora_wrapper.load_attn_procs(".")

Gotchas to be aware of:

  • We should probably not apply LoRA to the out projection and the key projection layers of the text encoder, following the original LoRA work. But I'm not sure what the community prefers; I guess we can revisit this if it becomes problematic.
  • The monkey patching doesn't expose the scale used for merging the LoRA parameters with the corresponding text encoder parameters, because cross_attention_kwargs doesn't apply to the text encoder (a rough sketch of the monkey patching follows below). If you have any ideas to tackle this issue, let me know.
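For context, here is a minimal sketch (not the PR's actual implementation) of what monkey patching LoRA onto the text encoder's attention projections could look like. The LoRALinear wrapper, the default rank of 4, the fixed scale, and the class-name check for CLIPAttention are illustrative assumptions.

import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear and adds a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base    # original (frozen) projection
        self.scale = scale  # fixed here, since cross_attention_kwargs doesn't reach the text encoder
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)  # LoRA starts as a no-op

    def forward(self, hidden_states):
        return self.base(hidden_states) + self.scale * self.up(self.down(hidden_states))

def monkey_patch_text_encoder(text_encoder: nn.Module, rank: int = 4) -> nn.ModuleList:
    """Replaces the q/v projections of every attention block with LoRA-wrapped ones."""
    lora_layers = nn.ModuleList()
    for module in text_encoder.modules():
        if module.__class__.__name__ == "CLIPAttention":  # assumption: match attention blocks by class name
            module.q_proj = LoRALinear(module.q_proj, rank)
            module.v_proj = LoRALinear(module.v_proj, rank)
            lora_layers.extend([module.q_proj, module.v_proj])
    return lora_layers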

Note that this PR does not modify the train_dreambooth_lora.py script yet. I want to do that in a separate PR once this one is merged (after the necessary modifications and discussions).

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Review thread on the `TextEncoderLoRAMixin.__init__` definition:

    The text encoder module underlying a [`~DiffusionPipeline`].
    """

    def __init__(self, text_encoder: nn.Module):
Contributor:
Can we remove the __init__ from the Mixin? I think we could call _initialize_lora_layers() in load_attn_procs no?

I'm not a fan of Mixins having inits because this means they can't be "plugged" into the StableDiffusionPipeline class.

Contributor:

+1, mixins should ideally not have __init__

@sayakpaul (Member, Author):

I don't envision this Mixin to be used with a DiffusionPipeline class.

_initialize_lora_layers() initializes the LoRA parameters, and calling it inside load_attn_procs() is not a good choice IMO, since the two methods do semantically different things.

Contributor:

Disclaimer: I'm thinking here more about inference of text encoder LoRAs than training.

I think we should have a function called:

load_lora()

or:

load_lora_weights(...)

That can be called from StableDiffusionPipeline(...)

I don't think we should wrap the text encoder:

text_encoder_lora_wrapper = TextEncoderLoRAMixin(text_encoder)

=> This breaks things for inference: text_encoder_lora_wrapper cannot be passed to the StableDiffusionPipeline because it doesn't have a forward method, it cannot be saved, etc.

@patrickvonplaten (Contributor) left a review:

I like the design a lot - it's super cool that we can re-use the AttnProcsLayers class here. I think it would however be nicer if we don't have to save a new weight file for the TextEncoder.

I'd maybe suggest to:

  • Call TextEncoderLoRAMixin just LoraLoaderMixin and inside the LoraLoaderMixin assume that the unet has a UNet2DConditionLoadersMixin
  • Remove the init - I don't think MixinLoaders should have an init(...)
  • Make sure only one file is saved per LoRA. If someone trains both the text encoder and the UNet LoRA, we only want one file to be saved IMO.

So IMO we should aim for the following API:

pipe = StableDiffusionPipeline.from_pretrained("...")
pipe.load_lora_weights("path-to-lora")

Now the load_lora_weights is part of the LoraLoaderMixin and it loads the state_dict from "path-to-lora" (either local or Hub or PyTorch state dict). Then passes the unet part of the loaded state dict to UNet2DConditionLoadersMixin.load_attn_procs (Note that this input accepts not just filenames but also PyTorch state_dicts). Then it calls _initialize_lora_layers for the text encoder and finally it loads the text part of the state_dict into the text encoder.

This way we can just plug this Mixin into every pipeline class and don't have to worry about any super().__init__() problems. The only problem here is that the LoraLoaderMixin has to know the names of the text encoder and the unet attributes. However, we could easily solve this with class attributes, e.g. we just give LoraLoaderMixin two class attributes:

class LoraLoaderMixin:
    text_encoder_name = None
    unet_name = None

And those are then overwritten in the StableDiffusionPipeline class (this is a common pattern that we already use, e.g.):

    config_name = CONFIG_NAME

I think we can just stick to the weight name "pytorch_lora_weights.bin"
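Here is a hedged sketch of that flow. The "unet." / "text_encoder." key prefixes, the local-only file resolution, and the assumption that _initialize_lora_layers returns the attached LoRA modules are illustrative, not the final diffusers API.

import os

import torch

class LoraLoaderMixin:
    text_encoder_name = None  # overridden by the pipeline class, e.g. "text_encoder"
    unet_name = None          # overridden by the pipeline class, e.g. "unet"

    def load_lora_weights(self, pretrained_path_or_dict, weight_name="pytorch_lora_weights.bin"):
        # 1. Resolve a state dict: an in-memory dict or a local folder containing the weight file
        #    (Hub download omitted for brevity).
        if isinstance(pretrained_path_or_dict, dict):
            state_dict = pretrained_path_or_dict
        else:
            state_dict = torch.load(os.path.join(pretrained_path_or_dict, weight_name), map_location="cpu")

        # 2. Split the single file into the UNet and text encoder parts by key prefix.
        unet_sd = {k[len("unet."):]: v for k, v in state_dict.items() if k.startswith("unet.")}
        text_sd = {k[len("text_encoder."):]: v for k, v in state_dict.items() if k.startswith("text_encoder.")}

        # 3. The UNet part goes through the existing loader, which also accepts state dicts.
        unet = getattr(self, self.unet_name)
        unet.load_attn_procs(unet_sd)

        # 4. Initialize the LoRA layers on the text encoder, then load the text part into them.
        text_encoder = getattr(self, self.text_encoder_name)
        text_encoder_lora_layers = self._initialize_lora_layers(text_encoder)
        text_encoder_lora_layers.load_state_dict(text_sd)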

Wdyt @sayakpaul ?

@patil-suraj (Contributor) left a review:

Very cool and great that we can leverage existing API!

+1 for Patrick's suggestion. Agree that having one loader class would be better; this way, users won't have to worry about using the mixin.


@sayakpaul (Member, Author):

> Then it calls _initialize_lora_layers for the text encoder and finally it loads the text part of the state_dict into the text encoder.

How are the LoRA layers for the text encoder initialized during training, then? We need to ensure that users follow the exact approach taken in _initialize_lora_layers() to initialize the LoRA parameters.

Review thread on the following hunk:

        vocab_size=1000,
    )
    text_encoder = CLIPTextModel(text_encoder_config).to(torch_device)
    text_encoder_lora_wrapper = TextEncoderLoRAMixin(copy.deepcopy(text_encoder))
Member:

I agree with @patrickvonplaten and @patil-suraj: if this is the way to use the new class, then it shouldn't be a mixin.

@pcuenca (Member) commented Mar 30, 2023

Following up on the discussion, I think we may need to have a TextEncoderLoRAMixin (or helper class, if we can't use a mixin) in addition to a pipeline-level mixin that calls both the text encoder and the UNet loaders. Would that work for training @sayakpaul?

@patrickvonplaten (Contributor):

I've mostly thought about inference here; things are indeed a bit trickier for training. I think the important part is the inference part though.

How about the following:

  1. We use the design from my review above (#2882 (review)) for inference. IMO that's the only design that works nicely.
  2. For training, we slightly adapt _initialize_lora_layers to accept a text encoder and return the text encoder LoRA layers, and we make it a class method.
    Then all we have to do for training is:
from diffusers.loaders import LoRALoaderMixin

...
text_encoder_lora_layers = LoRALoaderMixin.initialize_lora_layers(text_encoder)

=> This would also be a pretty nice API, since one only needs to train the LoRA layers and not the whole encoder; we've directly separated trainable weights from non-trainable weights.
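A hedged sketch of what suggestion 2 could look like: initialize_lora_layers as a class method that takes a text encoder, patches LoRA onto its q/v projections (along the lines of the monkey-patching sketch earlier in the thread), and returns only the trainable layers. The rank, the q/v-only targeting, and the class-name check are illustrative assumptions.

import torch.nn as nn

class LoRALoaderMixin:
    @classmethod
    def initialize_lora_layers(cls, text_encoder: nn.Module, rank: int = 4) -> nn.ModuleList:
        """Attaches LoRA to the q/v projections and returns only the trainable layers."""

        class _LoRALinear(nn.Module):
            def __init__(self, base: nn.Linear):
                super().__init__()
                self.base = base  # frozen original projection
                self.down = nn.Linear(base.in_features, rank, bias=False)
                self.up = nn.Linear(rank, base.out_features, bias=False)
                nn.init.zeros_(self.up.weight)  # start as a no-op

            def forward(self, hidden_states):
                return self.base(hidden_states) + self.up(self.down(hidden_states))

        lora_layers = nn.ModuleList()
        for module in text_encoder.modules():
            if module.__class__.__name__ == "CLIPAttention":  # assumption: match attention blocks by class name
                module.q_proj = _LoRALinear(module.q_proj)
                module.v_proj = _LoRALinear(module.v_proj)
                lora_layers.extend([module.q_proj, module.v_proj])
        return lora_layers

A training script would then pass only lora_layers.parameters() to the optimizer, keeping the frozen encoder weights untouched.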

@sayakpaul (Member, Author):

> I think the important part is the inference part though.

There's no inference if there's no training. So, I respectfully disagree.

But that said, I really like what you're proposing overall here.

But this introduces a discrepancy between how we initialize the LoRA layers for the UNet and the text encoder:

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)

For the UNet, the initialization part is handled explicitly, whereas for the text encoder, we're thinking of having a class method. I don't mind having an explicit initialization for the text encoder LoRA layers as well to keep the flow consistent and simple (over easy).

@sayakpaul (Member, Author):

Closing this because the conflicts are brutal.

Opened #2918.

@sayakpaul sayakpaul closed this Mar 31, 2023
@sayakpaul sayakpaul deleted the feat/lora-text-enc branch April 5, 2023 05:50
Linked issue: [LoRA] allow fine-tuning of the text encoder with LoRA (using peft)