
Commit 1395d64

reformat
1 parent b3ab842 commit 1395d64

File tree

1 file changed: +9 -6 lines changed


src/diffusers/models/cross_attention.py

Lines changed: 9 additions & 6 deletions
@@ -105,9 +105,8 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
-        is_lora = (
-            hasattr(self, "processor") and
-            isinstance(self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor))
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
         )
 
         if use_memory_efficient_attention_xformers:
@@ -148,7 +147,8 @@ def set_use_memory_efficient_attention_xformers
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
                     rank=self.processor.rank,
-                    attention_op=attention_op)
+                    attention_op=attention_op,
+                )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
             else:
@@ -158,7 +158,8 @@ def set_use_memory_efficient_attention_xformers
                 processor = LoRACrossAttnProcessor(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
-                    rank=self.processor.rank)
+                    rank=self.processor.rank,
+                )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
             else:
@@ -493,7 +494,9 @@ def __call__(
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
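For context on the code path these hunks touch, here is a minimal usage sketch (not part of the commit): it assumes a diffusers install containing this file, xformers, and a CUDA device, and the "runwayml/stable-diffusion-v1-5" checkpoint is only an illustrative choice.

import torch
from diffusers import StableDiffusionPipeline

# Any pipeline whose UNet is built from the CrossAttention blocks edited above.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# This call should propagate to CrossAttention.set_use_memory_efficient_attention_xformers
# on each attention block. Plain blocks get an XFormersCrossAttnProcessor; blocks already
# carrying a LoRA processor take the is_lora branch reformatted in this commit, which swaps
# in LoRAXFormersCrossAttnProcessor and copies the LoRA weights across via load_state_dict.
pipe.unet.set_use_memory_efficient_attention_xformers(True)

image = pipe("an astronaut riding a horse").images[0]

The pipeline-level helper pipe.enable_xformers_memory_efficient_attention() should end up on the same method.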
