Commit b3ab842

support disabling xformers

1 parent 2dc58e4 · commit b3ab842

File tree

1 file changed: +15 −5 lines

1 file changed

+15
-5
lines changed

src/diffusers/models/cross_attention.py

Lines changed: 15 additions & 5 deletions
@@ -105,6 +105,11 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
+        is_lora = (
+            hasattr(self, "processor") and
+            isinstance(self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor))
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -138,10 +143,7 @@ def set_use_memory_efficient_attention_xformers(
             except Exception as e:
                 raise e

-            if (
-                hasattr(self, "processor") and
-                isinstance(self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor))
-            ):
+            if is_lora:
                 processor = LoRAXFormersCrossAttnProcessor(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
@@ -152,7 +154,15 @@ def set_use_memory_efficient_attention_xformers(
             else:
                 processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
-            processor = CrossAttnProcessor()
+            if is_lora:
+                processor = LoRACrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank)
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = CrossAttnProcessor()

         self.set_processor(processor)

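The practical effect of this change: calling set_use_memory_efficient_attention_xformers(False) on a module whose processor carries LoRA weights no longer swaps in a plain CrossAttnProcessor and drops those weights; instead a LoRACrossAttnProcessor is rebuilt, its state dict copied over, and it is moved to the device of the existing LoRA layers. Below is a minimal sketch (not part of the commit) of that code path, using the module this commit touches; the query/hidden dimensions, heads, dim_head, and rank values are illustrative assumptions.

from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor

# Build a cross-attention block whose inner dim (8 heads * 40 dim_head = 320)
# matches the LoRA processor's hidden_size; all sizes here are illustrative.
attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_processor(LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4))

# Turning xformers off used to replace the processor with a plain
# CrossAttnProcessor and silently discard the LoRA weights; with this
# commit the LoRA processor (and its state dict) is preserved.
attn.set_use_memory_efficient_attention_xformers(False)
assert isinstance(attn.processor, LoRACrossAttnProcessor)

In a full pipeline, this path is typically reached via the pipeline-level enable_xformers_memory_efficient_attention() / disable_xformers_memory_efficient_attention() toggles after LoRA attention processors have been loaded.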
0 commit comments