1 change: 1 addition & 0 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -1178,6 +1178,7 @@ def forward(
         # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
         # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
         if cross_attention_kwargs is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
             lora_scale = cross_attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
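The added `copy()` matters because `dict.pop` mutates its receiver: without it, the first `forward` call removes `scale` from the caller's `cross_attention_kwargs`, and any later call that reuses the same dict silently falls back to the default scale. A minimal sketch of that failure mode, using a standalone toy `forward` rather than the UNet's actual signature:

def forward(cross_attention_kwargs=None):
    # Before the fix: pop() mutates the dict the caller passed in.
    if cross_attention_kwargs is not None:
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    return lora_scale

kwargs = {"scale": 0.5}
print(forward(kwargs))  # 0.5
print(forward(kwargs))  # 1.0 -- "scale" was removed by the first call

# With `cross_attention_kwargs = cross_attention_kwargs.copy()` before the
# pop, the caller's dict is left intact and both calls would return 0.5.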
9 changes: 7 additions & 2 deletions tests/lora/test_lora_layers_peft.py
@@ -158,7 +158,7 @@ def get_dummy_inputs(self, with_generator=True):

         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "num_inference_steps": 2,
+            "num_inference_steps": 5,
             "guidance_scale": 6.0,
             "output_type": "np",
         }
@@ -589,7 +589,7 @@ def test_simple_inference_with_text_unet_lora_and_scale(self):
             **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
         ).images
         self.assertTrue(
-            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+            not np.allclose(output_lora, output_lora_scale, atol=1e-4, rtol=1e-4),
             "Lora + scale should change the output",
         )
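For context, `np.allclose(a, b, atol=..., rtol=...)` treats elements as equal when `|a - b| <= atol + rtol * |b|`, so tightening both tolerances from 1e-3 to 1e-4 means even small differences between the scaled and unscaled outputs register as a change, which the assertion requires. A quick illustration with made-up values, not the test's actual outputs:

import numpy as np

a = np.array([1.0, 2.0])
b = a + 5e-4  # outputs that differ by 5e-4 per element

print(np.allclose(a, b, atol=1e-3, rtol=1e-3))  # True  -- dismissed as "close"
print(np.allclose(a, b, atol=1e-4, rtol=1e-4))  # False -- counted as a real change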

@@ -1300,6 +1300,11 @@ def test_integration_logits_with_scale(self):
         pipe.load_lora_weights(lora_id)
         pipe = pipe.to("cuda")
 
+        self.assertTrue(
+            self.check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in UNet",
+        )
+
         self.assertTrue(
             self.check_if_lora_correctly_set(pipe.text_encoder),
             "Lora not correctly set in text encoder 2",