[LoRA] fix cross_attention_kwargs problems and tighten tests
#7388
Conversation
Cc: @younesbelkada for viz.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Will also wait for @BenjaminBossan to approve it. And then I will proceed.
younesbelkada left a comment:
Nice catch! Thanks! One could also use `get` to avoid copying the kwargs at each forward!
The problem with …
ok makes sense! thanks for explaining!
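A rough sketch of the two approaches being weighed in the exchange above (illustrative only, not the exact diffusers code; the function names are made up for the comparison):

```python
# Illustrative sketch only -- not the actual diffusers implementation.

# Option raised above: read the scale without mutating the caller's dict.
def read_scale_with_get(cross_attention_kwargs):
    lora_scale = (cross_attention_kwargs or {}).get("scale", 1.0)
    # "scale" stays inside cross_attention_kwargs and is still forwarded downstream.
    return lora_scale, cross_attention_kwargs

# Roughly what this PR does: copy first, then pop, so the caller's dict is
# untouched and "scale" is stripped before the kwargs are forwarded.
def read_scale_with_copy_and_pop(cross_attention_kwargs):
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    return lora_scale, cross_attention_kwargs
```

The trade-off: `get` avoids the per-forward copy but leaves `scale` in the kwargs that get passed further down, while copy-and-pop keeps the caller's dict intact and still removes the key.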
BenjaminBossan left a comment:
Thanks for fixing this bug, I think the copy solution is solid.
* debugging
* let's see the numbers
* let's see the numbers
* let's see the numbers
* restrict tolerance.
* increase inference steps.
* shallow copy of cross_attention_kwargs
* remove print
What does this PR do?
First of all, I would like to apologize for not being rigorous enough with #7338. This was actually breaking:
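A sketch of the kind of call that exercised the bug (the checkpoint and LoRA paths below are placeholders, not the original snippet):

```python
# Placeholder model/LoRA IDs; shown only to illustrate the failing pattern.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # hypothetical LoRA checkpoint

# The LoRA scale is passed via cross_attention_kwargs. Before this fix,
# "scale" was popped from this very dict inside the UNet, so only the first
# denoising step used 0.5 and the remaining steps fell back to 1.0.
image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=30,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```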
This is because `pop()` removes the requested key from the underlying dictionary the first time it is called, so every subsequent call falls back to the default value. Since the `unet` within a `DiffusionPipeline` is called iteratively, this creates a lot of unexpected consequences. As a result, the above-mentioned test fails. Here are the `lora_scale` values: notice how the value defaults to 1.0 after the first round of denoising steps.
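A minimal, standalone demonstration of the `pop()` behavior described above (plain Python, no diffusers required):

```python
# The caller builds cross_attention_kwargs once; the pipeline then reuses the
# same dict for every denoising step.
cross_attention_kwargs = {"scale": 0.5}

for step in range(3):
    # Inside the UNet forward: pop() removes "scale" from the caller's dict on
    # the first call, so every later call only sees the default.
    lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    print(f"step {step}: lora_scale = {lora_scale}")

# Output:
# step 0: lora_scale = 0.5
# step 1: lora_scale = 1.0
# step 2: lora_scale = 1.0
#
# Making a shallow copy of the dict before popping (the fix in this PR) keeps
# the caller's cross_attention_kwargs intact, so every step sees 0.5.
```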
A simple solution is to create a shallow copy of `cross_attention_kwargs` so that the original one is left untouched. This is what this PR does.

Additionally, you may wonder why the below set of tests PASS?
`pytest tests/lora/test_lora_layers_peft.py -k "test_simple_inference_with_text_unet_lora_and_scale"`

My best guess is that this is because we use too few `num_inference_steps` to validate things. To see if my hunch was right, I increased `num_inference_steps` to 5 here and ran these tests WITHOUT the changes introduced in this PR (i.e., without the shallow copy). All of those tests failed. With the changes, they pass.

Once this PR is merged, I will take care of making another patch release.
Once again, I am genuinely sorry for the oversight on my end.