
Add draft for lora text encoder scale #3626


Merged: 16 commits merged into main on Jun 6, 2023

Conversation

@patrickvonplaten (Contributor) commented May 31, 2023

Draft PR to show how we could correctly deal with LoRA scale for the text encoder

@HuggingFaceDocBuilderDev commented May 31, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -839,6 +839,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)

# set lora scale to a reasonable default
self._scale = 1
@sayakpaul (Member):

But we'd want the users to also specify this explicitly if they want to, no?

For that, don't you think exposing an argument for it makes sense?

@patrickvonplaten (Contributor, Author):

Yes, actually users should never set this argument themselves (which is why it's marked private with a _). It should just be used internally to make it work with LoRA.

@sayakpaul (Member):

Well, we do allow adjusting the scale of the UNet LoRA via cross_attention_kwargs. Refer to this doc:
https://huggingface.co/docs/diffusers/main/en/training/lora

By default, we already use the scale argument (with a value of 1) here:

self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None

By exposing the scale argument to users, they have explicit control over how strongly the LoRA affects the output.

Let me know if I am missing out on something.
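
For reference, the UNet-side control described above looks roughly like this (the model and LoRA repo names below are placeholders, not taken from this PR):

import torch
from diffusers import StableDiffusionPipeline

# Placeholder model and LoRA names, for illustration only.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora")

# `scale` is forwarded to the attention processors quoted above;
# 0.0 disables the UNet LoRA, 1.0 applies it at full strength.
image = pipe(
    "a photo of a cat",
    cross_attention_kwargs={"scale": 0.7},
).images[0]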

@patrickvonplaten (Contributor, Author):

I think we have a misunderstanding here, let's maybe take this offline in a quick call next week :-)

@patrickvonplaten (Contributor, Author):

@sayakpaul the purpose of this PR is to create a mechanism that allows the user to change the LoRA scale of the text encoder. I've adapted the PR to make it functional and used it with your a1111 example here:

#!/usr/bin/env python3
import torch

from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

path = "gsdf/Counterfeit-V2.5"

pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)
pipe = pipe.to("cuda")

# LoRA checkpoint that contains both UNet and text encoder layers
pipe.load_lora_weights(".", weight_name="light_and_shadow.safetensors")

prompt = "masterpiece, best quality, 1girl, at dusk"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts")

pipe.enable_xformers_memory_efficient_attention()
images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=512,
    height=768,
    num_inference_steps=15,
    num_images_per_prompt=4,
    cross_attention_kwargs={"scale": 0.5},  # scales both the UNet and the text encoder LoRA
    generator=torch.manual_seed(0),
).images

Note how cross_attention_kwargs={"scale": 0.5} now changes the scale of both the text encoder and the UNet. Does this PR make more sense now?
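
For readers wondering how the text-encoder side picks that value up, here is a simplified sketch with hypothetical names (not the actual diffusers implementation): the pipeline stores the scale, and each patched text-encoder linear layer multiplies its low-rank update by it.

import torch
import torch.nn as nn

# Simplified sketch with hypothetical names -- not the real diffusers code.
class ScaledLoRALinear(nn.Module):
    """A frozen Linear layer plus a scaled low-rank (LoRA) update."""

    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        self.scale = 1.0  # analogous to the pipeline-level default of 1 set above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # base output + scale * low-rank correction
        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))

In the pipeline call, the same cross_attention_kwargs["scale"] value would then be written into every wrapped text-encoder layer and also passed down to the UNet, which is why a single number now controls both.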

@sayakpaul (Member) commented Jun 5, 2023

> Note how cross_attention_kwargs={"scale": 0.5} now changes the scale of both the text encoder and the UNet.

But now we are assuming users will always want to use the same scale value for the UNet and the text encoder. That might not actually be the case, no?

P.S.: I was (from day one) clear about the scope of the PR. But from your descriptions, I was not sure how you were thinking about how users would pass the scale parameter for the text encoder. I was thinking of separate scale arguments for the UNet and the text encoder LoRAs. I think I indicated that clearly in my comments so far, but apologies if they were not crystal clear.
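
Purely to make that distinction concrete, independent knobs could be modeled like this (hypothetical sketch; nothing below exists in diffusers or in this PR):

from dataclasses import dataclass

# Hypothetical illustration only -- diffusers does not expose this structure.
@dataclass
class LoraScales:
    unet: float = 1.0          # strength of the UNet LoRA layers
    text_encoder: float = 1.0  # strength of the text-encoder LoRA layers

# e.g. LoraScales(unet=0.5, text_encoder=0.8) expresses the case where a user
# wants the text-encoder LoRA to contribute more than the UNet LoRA.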

Comment on lines +265 to +267
If your LoRA parameters involve the UNet as well as the Text Encoder, then passing
`cross_attention_kwargs={"scale": 0.5}` will apply the `scale` value to both the UNet
and the Text Encoder.
@sayakpaul (Member):

Documentation.

@patrickvonplaten (Contributor, Author):

Nice!

@@ -870,6 +870,9 @@ def main(args):
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, text_encoder=text_encoder
)
# Setting the `_lora_scale` explicitly because we are not using
# `load_lora_weights()`.
temp_pipeline._lora_scale = 1.0
@sayakpaul (Member):

Otherwise, forward pass with the text encoder won't be possible.
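
To make the failure mode concrete, a tiny hypothetical accessor (not the actual diffusers code):

# Hypothetical helper, for illustration only.
def current_lora_scale(pipeline) -> float:
    # The patched text-encoder layers need a scale value at forward time.
    # When `load_lora_weights()` is bypassed (as in this training script) and
    # `_lora_scale` is never set, a bare `pipeline._lora_scale` lookup would
    # fail, which is why the script sets the attribute to 1.0 by hand.
    return getattr(pipeline, "_lora_scale", 1.0)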

@patrickvonplaten (Contributor, Author):

@sayakpaul, I've now added the LoRALoader to all models where IMO it makes sense. For some pipelines like pix2pix_zero it didn't make much sense, so I just added an import so that the `# Copied from` checks still pass. IMO it's more important that the code stays in sync than to have 100% clean code for this edge case.

Will merge now so that we have results from the slow tests tomorrow.

@patrickvonplaten patrickvonplaten merged commit 74fd735 into main Jun 6, 2023
@patrickvonplaten patrickvonplaten deleted the lora_text_encoder_scale branch June 6, 2023 21:47
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Add draft for lora text encoder scale

* Improve naming

* fix: training dreambooth lora script.

* Apply suggestions from code review

* Update examples/dreambooth/train_dreambooth_lora.py

* Apply suggestions from code review

* Apply suggestions from code review

* add lora mixin when fit

* add lora mixin when fit

* add lora mixin when fit

* fix more

* fix more

---------

Co-authored-by: Sayak Paul <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024