Commit 73a91d4

Fix to apply LoRAXFormersAttnProcessor instead of LoRAAttnProcessor when xFormers is enabled (huggingface#3556)
* fix to use LoRAXFormersAttnProcessor
* add test
* use new LoraLoaderMixin.save_lora_weights
* add test_lora_save_load_with_xformers
1 parent f8c0586 commit 73a91d4
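
The dispatch this commit introduces can be read in isolation. A minimal sketch, assuming only the processor classes named in the diff; `pick_lora_processor_class` is a hypothetical helper for illustration, not part of the diffusers API:

```python
from diffusers.models.attention_processor import (
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

def pick_lora_processor_class(attn_processor):
    # Hypothetical helper mirroring the commit's logic: if the attention
    # module already runs an xFormers-backed processor, keep the xFormers
    # LoRA variant; otherwise fall back to the plain LoRA processor.
    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
        return LoRAXFormersAttnProcessor
    return LoRAAttnProcessor
```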

File tree

1 file changed: +6 −1 lines changed


loaders.py

Lines changed: 6 additions & 1 deletion
@@ -27,7 +27,9 @@
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
+    XFormersAttnProcessor,
 )
 from .utils import (
     DIFFUSERS_CACHE,
@@ -279,7 +281,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 attn_processor_class = LoRAAttnAddedKVProcessor
             else:
                 cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-                attn_processor_class = LoRAAttnProcessor
+                if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
+                    attn_processor_class = LoRAXFormersAttnProcessor
+                else:
+                    attn_processor_class = LoRAAttnProcessor
 
             attn_processors[key] = attn_processor_class(
                 hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
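
In practice, the fix matters when xFormers is enabled before LoRA weights are loaded. A minimal usage sketch, assuming a CUDA machine with xformers installed; the model id is the common v1-5 checkpoint and "path/to/lora" is a hypothetical placeholder:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Installs XFormersAttnProcessor on every attention module.
pipe.enable_xformers_memory_efficient_attention()

# Before this commit, load_attn_procs replaced those processors with plain
# LoRAAttnProcessor, silently dropping memory-efficient attention; with the
# fix it selects LoRAXFormersAttnProcessor instead.
pipe.unet.load_attn_procs("path/to/lora")
```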
