
Conversation

@stevenjlm
Contributor

@stevenjlm stevenjlm commented Apr 4, 2024

What does this PR do?

This PR skips scaling LoRA modules during the UNet forward pass when the LoRA scale is 1.0 (and thus scaling would have no effect downstream). In profiling tests, I have found that for SDXL loaded with LoRAs, a substantial amount of inference time is spent looping through modules in the scale_lora_layers and unscale_lora_layers methods. If the LoRA scale is 1.0, this loop has no effect, so we might as well skip it.

There are additional details on this at the bottom of this description, in the "Performance Details" section.
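For concreteness, here is a minimal sketch of the kind of call-site guard this describes, assuming diffusers' scale_lora_layers / unscale_lora_layers helpers and a simplified wrapper function (run_unet_with_lora_scale is a made-up name for illustration, not the actual diff):

from diffusers.utils import scale_lora_layers, unscale_lora_layers


def run_unet_with_lora_scale(unet, sample, timestep, encoder_hidden_states, lora_scale=1.0):
    # Only loop over the LoRA modules when scaling actually changes something;
    # scaling by 1.0 is the identity, so the loop can be skipped entirely.
    if lora_scale != 1.0:
        scale_lora_layers(unet, lora_scale)

    out = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample

    if lora_scale != 1.0:
        unscale_lora_layers(unet, lora_scale)
    return out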

Before submitting

(It's a small enough change that I'm not sure it warrants doc or test updates, but I'll be happy to add them if requested.)

Who can review?

@sayakpaul @yiyixuxu @DN6

Performance Details

Below are the results from using cProfile, and at the bottom is a minimal code snippet I used for these profiles.

Profiler before code change: [profiler screenshot]
After code change: [profiler screenshot]
The generated output looks similar with and without the change.

from cProfile import Profile
from datetime import datetime
from io import BytesIO

import requests
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLImg2ImgPipeline
from PIL import Image
from progressbar import progressbar

profiler = Profile()

MODEL_CACHE = "diffusers-cache"
FUSE = True
lora_ids = {
    "ikea": "ostris/ikea-instructions-lora-sdxl",
}

lora_keywords = {
    "ikea": "ikea",
}

# ------------------------------------- Load base model
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    # "stablediffusionapi/juggernaut-xl-v8",
    cache_dir=MODEL_CACHE,
    torch_dtype=torch.float16,
).to("cuda")

# ------------------------------------- Load Loras
for lora_name, lora_id in progressbar(lora_ids.items()):
    state_dict, _ = pipe.lora_state_dict(
        lora_id,
        unet_config=pipe.unet.config,
    )
    pipe.load_lora_weights(
        state_dict,
        unet=pipe.unet,
        adapter_name=lora_name,
    )

# ------------------------------------- Run Inference
profiler.enable()
in_url = "https://media.cnn.com/api/v1/images/stellar/prod/200605082916-01-real-tiger-king-siberian.jpg"
prompt = "tiger"
lora_name = "ikea"
num_samples = 1

pipe.scheduler = DPMSolverMultistepScheduler(use_karras_sigmas=True, algorithm_type="dpmsolver++")
pipe.set_adapters(lora_name)

if FUSE:
    # Fuse the selected LoRA into the base weights once, after choosing the adapter.
    pipe.fuse_lora()
response = requests.get(in_url)
og_image = Image.open(BytesIO(response.content))

output = pipe(
    image=og_image.convert("RGB"),
    prompt=[lora_keywords[lora_name]] * num_samples,
    num_inference_steps=20,
    generator=torch.Generator("cuda").manual_seed(42),
    strength=0.5,
)

if FUSE:
    pipe.unfuse_lora()

all_images = output.images
output_paths = []

profiler.disable()
profiler.dump_stats(
    f"profile-{datetime.now().strftime('%Y%m%d-%H%M%S')}.prof"
)

for i, sample in enumerate(all_images):
    output_path = f"out-{i}.png"
    sample.save(output_path)
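To inspect the dumped .prof files, a short pstats snippet like this works (the file name below is a placeholder for whatever timestamped file the script wrote):

import pstats

# Print the 20 entries with the highest cumulative time from the dumped profile.
stats = pstats.Stats("profile-20240404-120000.prof")  # placeholder file name
stats.strip_dirs().sort_stats("cumulative").print_stats(20)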

@sayakpaul
Member

Thanks for your PR. Could you quantify the time difference?

Cc: @BenjaminBossan here.

@BenjaminBossan
Member

Thanks for investigating this issue. Indeed, scaling is unnecessary if the scale is 1 -- in fact, we already have a check in PEFT that skips the scaling in that case. The issue seems to be the looping over the modules and the isinstance check for each module, which, as correctly stated, can be skipped.

My suggestion for this PR, however, is to move the skipping logic inside of scale_lora_layers and unscale_lora_layers. The reason is that if those functions are called elsewhere, all these callers benefit from the optimization. Otherwise, they would each need to perform the same check.

Of course, this adds another function call to the stack, but that should be very negligible overall.
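Put differently, a minimal sketch of the guard living inside the helper (the module iteration and BaseTunerLayer check mirror diffusers' PEFT utilities, but this is illustrative, not the merged diff):

from peft.tuners.tuners_utils import BaseTunerLayer


def scale_lora_layers(model, weight):
    # Early exit: scaling by 1.0 is the identity, so skip the module loop
    # (and its per-module isinstance checks) entirely.
    if weight == 1.0:
        return

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            module.scale_layer(weight)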

Thanks for your PR. Could you quantify the time difference?

If I read the graph correctly, the difference is from 4.6 sec to 2.3 sec.

@sayakpaul
Member

Thanks @BenjaminBossan for your comments.

@stevenjlm I will let you address the comments and we can take it from there.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@stevenjlm
Contributor Author

@BenjaminBossan @sayakpaul thanks for the feedback and guidance! I will move the logic inside scale_lora_layers and unscale_lora_layers.

@stevenjlm
Contributor Author

stevenjlm commented Apr 5, 2024

@BenjaminBossan @sayakpaul Moved the logic inside scale_lora_layers and unscale_lora_layers. I also redid the performance check. As expected, the improvement is very similar: It goes from ~4.6 seconds without the changes down to ~2.3 seconds with the changes.

Member

@BenjaminBossan BenjaminBossan left a comment


LGTM, thanks for digging in and working on this performance improvement.

Could you please fix the code quality issues and submit again?

new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
    module.weight.device
)
new_module = torch.nn.Linear(
Member


Why this change?

Contributor Author


It's the formatter; when I ran make style, it made this change.

Contributor Author


Would you prefer that I undo it?

Member

@sayakpaul sayakpaul left a comment


Thanks very much. Just one comment, but I think we're good to go.

@sayakpaul
Member

@stevenjlm could you push an empty commit on your end? I think the failing test is unrelated.

@stevenjlm
Contributor Author

Pushed an empty commit; hopefully the workflows pass. @sayakpaul

@stevenjlm
Contributor Author

I'm looking into this failing test to see if there's anything I can do to fix it.

@stevenjlm
Contributor Author

@sayakpaul I see that @yiyixuxu commented out the test that was failing in #7620, so checks should pass now that I've rebased.

@sayakpaul
Member

Thanks! Please tag me once the CI run is complete.

@stevenjlm
Contributor Author

@sayakpaul CI run passed!

@sayakpaul sayakpaul merged commit 42f25d6 into huggingface:main Apr 11, 2024
@sayakpaul
Member

Thanks for this cool contribution!

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* Skip scaling if scale is identity

* move check for weight one to scale and unscale lora

* fix code style/quality

* Empty-Commit

---------

Co-authored-by: Steven Munn <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Steven Munn <[email protected]>