
Conversation

patrickvonplaten (Contributor) commented Jan 26, 2023

This should allow the following usage:

from diffusers import StableDiffusionPipeline
import torch

# Hub repository containing the fine-tuned LoRA attention weights
model_path = "sayakpaul/sd-model-finetuned-lora-t4"

# Load the base Stable Diffusion pipeline in fp16
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)

# Load the LoRA attention processors into the UNet
pipe.unet.load_attn_procs(model_path)
pipe.to("cuda")

prompt = "A pokemon with blue eyes."
# cross_attention_kwargs={"scale": 0.5} controls how strongly the LoRA weights are applied
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={"scale": 0.5}).images[0]
image.save("pokemon.png")

Also cc @sayakpaul @apolinario

patrickvonplaten (Contributor, Author) left a comment

@pcuenca @patil-suraj @hysts could you take a look here? To use LoRA from the pipeline, we need to allow passing cross_attention_kwargs to the pipeline's call function.
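For context, the `scale` entry in cross_attention_kwargs is simply forwarded as a keyword argument to the attention processor, which uses it to weight the LoRA update. A rough, illustrative sketch of that idea (the class and layer names below are hypothetical, not the actual diffusers implementation):

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    # Hypothetical layer: a frozen base projection plus a low-rank (LoRA) update.
    def __init__(self, features, rank=4):
        super().__init__()
        self.base = nn.Linear(features, features)
        self.down = nn.Linear(features, rank, bias=False)
        self.up = nn.Linear(rank, features, bias=False)

    def forward(self, x, scale=1.0):
        # `scale` plays the same role as cross_attention_kwargs={"scale": ...} in the
        # pipeline call: it weights how much of the LoRA update is mixed into the base projection.
        return self.base(x) + scale * self.up(self.down(x))

layer = TinyLoRALinear(8)
x = torch.randn(1, 8)
out_half = layer(x, scale=0.5)  # 50% LoRA influence
out_base = layer(x, scale=0.0)  # base projection only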

HuggingFaceDocBuilderDev commented Jan 26, 2023

The documentation is not available anymore as the PR was closed or merged.

patil-suraj (Contributor) left a comment

Looks good to me!

hysts (Contributor) left a comment

LGTM!

sayakpaul (Member) left a comment

Thanks!

patrickvonplaten merged commit 0c39f53 into main on Jan 27, 2023
patrickvonplaten deleted the allow_lora_from_pipeline branch on January 27, 2023 at 07:20
jndietz commented Feb 15, 2023

With xformers enabled, this appears to raise an error on the most recent main dev build.

Traceback (most recent call last):
  File "C:\projects\python-test\pipeline.py", line 19, in <module>
    image = pipe(prompt, negative_prompt=negative_prompt, width=512, height=768, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={"scale": 0.5}).images[0]     
  File "C:\Users\Jared\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "C:\github\huggingface\diffusers\src\diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion.py", line 610, in __call__
    noise_pred = self.unet(
  File "C:\Users\Jared\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\github\huggingface\diffusers\src\diffusers\models\unet_2d_condition.py", line 580, in forward
    sample, res_samples = downsample_block(
  File "C:\Users\Jared\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\github\huggingface\diffusers\src\diffusers\models\unet_2d_blocks.py", line 837, in forward
    hidden_states = attn(
  File "C:\Users\Jared\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\github\huggingface\diffusers\src\diffusers\models\transformer_2d.py", line 265, in forward
    hidden_states = block(
  File "C:\Users\Jared\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\github\huggingface\diffusers\src\diffusers\models\attention.py", line 291, in forward
    attn_output = self.attn1(
  File "C:\Users\Jared\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\github\huggingface\diffusers\src\diffusers\models\cross_attention.py", line 202, in forward
    return self.processor(
TypeError: XFormersCrossAttnProcessor.__call__() got an unexpected keyword argument 'scale'

sayakpaul (Member)

@patrickvonplaten seems similar to #2334 (comment), no?

patrickvonplaten (Contributor, Author)
Hmm, no, I think the wrong XFormersCrossAttnProcessor is loaded here. @jndietz could you open a new issue with a reproducible code snippet? :-)
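In the meantime, one way to debug the mismatch (a rough sketch; it assumes the attn_processors property and enable_xformers_memory_efficient_attention are available in your diffusers version) is to check which processor classes are actually attached to the UNet, and to re-load the LoRA attention processors after enabling xformers so they are not overwritten:

from collections import Counter

# See which attention processor classes are currently attached to the UNet.
print(Counter(type(proc).__name__ for proc in pipe.unet.attn_processors.values()))

# If the xformers processor has replaced the LoRA processors, re-loading the LoRA weights
# after enabling xformers may help (the ordering of these two calls is the key assumption here):
pipe.enable_xformers_memory_efficient_attention()
pipe.unet.load_attn_procs(model_path)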

thihamin

@patrickvonplaten Does the LoRA stay in the pipe after an inference? We keep the pipe in memory and use it for multiple inferences, and it seems the previously loaded LoRA stays in memory for the next inference when pipe.unet.load_attn_procs(model_path) is not called again. Is it possible to detach the LoRA before the next inference?

Also, diffusers only supports one LoRA at a time, right?

patrickvonplaten (Contributor, Author)
Hey @thihamin,

That's right, the previously loaded LoRA is kept if no new attn processor is loaded.
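If you want to drop the LoRA between inferences, one option (a minimal sketch, assuming set_attn_processor and the default CrossAttnProcessor class from diffusers.models.cross_attention in the diffusers version from this era) is to reset the UNet to its default attention processors:

from diffusers.models.cross_attention import CrossAttnProcessor

# Replace every attention processor with the stock one, which discards the LoRA weights.
pipe.unet.set_attn_processor(CrossAttnProcessor())

# A different LoRA can then be loaded as usual:
# pipe.unet.load_attn_procs(other_model_path)  # hypothetical path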

kirit93 commented Mar 7, 2023

Will there be support in Diffusers for two or more LoRAs anytime soon?

patrickvonplaten (Contributor, Author)
Hey @kirit93, could you maybe open a feature request for this? :-)

kirit93 commented Mar 8, 2023

#2613
Thanks @patrickvonplaten!

sonaterai commented Mar 8, 2023

Hi, has anyone found a way to use a LoRA .safetensors file, or to convert one? Unfortunately, nothing I have found works.
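For what it's worth, a minimal sketch for at least inspecting such a file (this only assumes the safetensors package; whether load_attn_procs can consume the resulting state dict depends on how the LoRA was exported and on your diffusers version):

from safetensors.torch import load_file

# Load the raw tensors and look at the key names; the naming scheme usually reveals
# which tool exported the LoRA (diffusers-style vs. other trainers).
state_dict = load_file("my_lora.safetensors")  # hypothetical file name
for key in list(state_dict.keys())[:10]:
    print(key, tuple(state_dict[key].shape))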

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request on Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024