
Commit 5d4f59e

Fix running LoRA with xformers (#2286)
* Fix running LoRA with xformers
* support disabling xformers
* reformat
* Add test
1 parent f2eae16 commit 5d4f59e
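
For context, the sketch below is not part of this commit; it shows the user-facing behaviour the fix restores: a UNet2DConditionModel with LoRA attention processors installed keeps its LoRA weights when xformers attention is switched on or off. It assumes diffusers at roughly this commit, the runwayml/stable-diffusion-v1-5 checkpoint, and a CUDA machine with xformers installed.

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Build one LoRA processor per attention module, keyed like unet.attn_processors.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross-attention dim); attn2 attends to the text embeddings.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)
unet.to("cuda")

# Before this commit, enabling xformers silently replaced every LoRA processor with a plain
# XFormersCrossAttnProcessor (dropping the LoRA weights), and disabling it downgraded to
# CrossAttnProcessor; with this fix both calls carry the LoRA weights across.
unet.enable_xformers_memory_efficient_attention()
unet.disable_xformers_memory_efficient_attention()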

File tree

2 files changed: +69, -4 lines
  src/diffusers/models/cross_attention.py
  tests/models/test_models_unet_2d_condition.py


src/diffusers/models/cross_attention.py

Lines changed: 40 additions & 4 deletions
@@ -105,6 +105,10 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -138,9 +142,28 @@ def set_use_memory_efficient_attention_xformers(
                 except Exception as e:
                     raise e

-            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
+            if is_lora:
+                processor = LoRAXFormersCrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
-            processor = CrossAttnProcessor()
+            if is_lora:
+                processor = LoRACrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = CrossAttnProcessor()

         self.set_processor(processor)

@@ -324,6 +347,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -437,9 +464,14 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No


 class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        self.attention_op = attention_op
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -462,7 +494,9 @@ def __call__(
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()

-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
@@ -595,4 +629,6 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
     SlicedAttnProcessor,
     CrossAttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    LoRACrossAttnProcessor,
+    LoRAXFormersCrossAttnProcessor,
 ]
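
To make the swap above explicit: both LoRA processor classes use the same parameter names (to_q_lora, to_k_lora, to_v_lora, to_out_lora), so the patch can rebuild the processor in the other flavour, copy the weights over with load_state_dict, and keep it on the device of the existing LoRA layers. A minimal sketch of that pattern, not taken from the commit and using made-up sizes (320/768):

import torch
from diffusers.models.cross_attention import LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor

# Hypothetical sizes; in the UNet they come from the block's channel width and the
# text-encoder embedding dimension.
old = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)

# Rebuild in the xformers flavour from the attributes this commit starts storing ...
new = LoRAXFormersCrossAttnProcessor(
    hidden_size=old.hidden_size,
    cross_attention_dim=old.cross_attention_dim,
    rank=old.rank,
    attention_op=None,
)
# ... copy the LoRA weights (identical parameter names in both classes) ...
new.load_state_dict(old.state_dict())
# ... and keep the replacement on the same device as the existing LoRA layers.
new.to(old.to_q_lora.up.weight.device)

# The copied LoRA down-projections now match exactly.
assert torch.equal(new.to_q_lora.down.weight, old.to_q_lora.down.weight)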

tests/models/test_models_unet_2d_condition.py

Lines changed: 29 additions & 0 deletions
@@ -412,6 +412,35 @@ def test_lora_on_off(self):
         assert (sample - new_sample).abs().max() < 1e-4
         assert (sample - old_sample).abs().max() < 1e-4

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_lora_xformers_on_off(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        # default
+        with torch.no_grad():
+            sample = model(**inputs_dict).sample
+
+            model.enable_xformers_memory_efficient_attention()
+            on_sample = model(**inputs_dict).sample
+
+            model.disable_xformers_memory_efficient_attention()
+            off_sample = model(**inputs_dict).sample
+
+        assert (sample - on_sample).abs().max() < 1e-4
+        assert (sample - off_sample).abs().max() < 1e-4
+

 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):
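
One detail that makes the new test's 1e-4 tolerances work (an implementation detail of diffusers, not something shown in this diff): as I read LoRALinearLayer, its up-projection weight is zero-initialized, so freshly created LoRA processors add exactly nothing to the attention output, and the only difference between the three samples is the numerical noise of the xformers kernel. A minimal sketch of that assumption, with made-up sizes:

import torch
from diffusers.models.cross_attention import LoRALinearLayer

# 32 in-features, 32 out-features, rank 4 -- arbitrary sizes for illustration.
lora = LoRALinearLayer(32, 32, 4)
x = torch.randn(2, 32)

# The up-projection starts at zero, so the LoRA delta is exactly zero until training.
assert torch.all(lora(x) == 0)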
