Commit 67cf044

Fix to apply LoRAXFormersAttnProcessor instead of LoRAAttnProcessor when xFormers is enabled (#3556)
* fix to use LoRAXFormersAttnProcessor
* add test
* use the new LoraLoaderMixin.save_lora_weights
* add test_lora_save_load_with_xformers
1 parent: 352ca31
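
As a rough sketch of what this change enables (not part of the commit; the model id and LoRA directory below are placeholders, and it assumes a CUDA machine with xformers installed), enabling memory-efficient attention before loading LoRA weights should now leave the UNet with xFormers-aware LoRA processors:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor

# Placeholder model id and LoRA directory; any SD checkpoint / LoRA file follows the same flow.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Enable xFormers attention first, then load the LoRA weights.
pipe.enable_xformers_memory_efficient_attention()
pipe.load_lora_weights("path/to/lora_weights_dir")

# With a LoRA file covering all attention modules (as in the tests below), the loaded
# processors should now be LoRAXFormersAttnProcessor, not the plain LoRAAttnProcessor.
assert all(
    isinstance(proc, LoRAXFormersAttnProcessor)
    for proc in pipe.unet.attn_processors.values()
)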

File tree: 2 files changed (+101, −2 lines)


src/diffusers/loaders.py

Lines changed: 6 additions & 1 deletion
@@ -27,7 +27,9 @@
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
+    XFormersAttnProcessor,
 )
 from .utils import (
     DIFFUSERS_CACHE,
@@ -279,7 +281,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     attn_processor_class = LoRAAttnAddedKVProcessor
                 else:
                     cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-                    attn_processor_class = LoRAAttnProcessor
+                    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
+                        attn_processor_class = LoRAXFormersAttnProcessor
+                    else:
+                        attn_processor_class = LoRAAttnProcessor

                 attn_processors[key] = attn_processor_class(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
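
For readability, the branch added above amounts to the following dispatch (a sketch only; pick_lora_processor_class is a hypothetical helper name used for illustration, the real code stays inline in load_attn_procs):

from diffusers.models.attention_processor import (
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

def pick_lora_processor_class(attn_processor):
    # If the module is already running an xFormers processor (plain or LoRA variant),
    # keep xFormers attention by choosing the xFormers LoRA processor; otherwise
    # fall back to the default LoRA processor.
    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
        return LoRAXFormersAttnProcessor
    return LoRAAttnProcessor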

tests/models/test_lora_layers.py

Lines changed: 95 additions & 1 deletion
@@ -22,7 +22,14 @@

 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
+    AttnProcessor2_0,
+    LoRAAttnProcessor,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device

@@ -212,3 +219,90 @@ def test_lora_save_load_legacy(self):

         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+    def create_lora_weight_file(self, tmpdirname):
+        _, lora_components = self.get_dummy_components()
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=tmpdirname,
+            unet_lora_layers=lora_components["unet_lora_layers"],
+            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+        )
+        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+
+    def test_lora_unet_attn_processors(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.create_lora_weight_file(tmpdirname)
+
+            pipeline_components, _ = self.get_dummy_components()
+            sd_pipe = StableDiffusionPipeline(**pipeline_components)
+            sd_pipe = sd_pipe.to(torch_device)
+            sd_pipe.set_progress_bar_config(disable=None)
+
+            # check if vanilla attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, (AttnProcessor, AttnProcessor2_0))
+
+            # load LoRA weight file
+            sd_pipe.load_lora_weights(tmpdirname)
+
+            # check if lora attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, LoRAAttnProcessor)
+
+    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    def test_lora_unet_attn_processors_with_xformers(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.create_lora_weight_file(tmpdirname)
+
+            pipeline_components, _ = self.get_dummy_components()
+            sd_pipe = StableDiffusionPipeline(**pipeline_components)
+            sd_pipe = sd_pipe.to(torch_device)
+            sd_pipe.set_progress_bar_config(disable=None)
+
+            # enable XFormers
+            sd_pipe.enable_xformers_memory_efficient_attention()
+
+            # check if xFormers attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
+            # load LoRA weight file
+            sd_pipe.load_lora_weights(tmpdirname)
+
+            # check if lora attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor)
+
+    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    def test_lora_save_load_with_xformers(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()
+
+        # enable XFormers
+        sd_pipe.enable_xformers_memory_efficient_attention()
+
+        original_images = sd_pipe(**pipeline_inputs).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            LoraLoaderMixin.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+        lora_images = sd_pipe(**pipeline_inputs).images
+        lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
