
Commit 4d35d7f

Allow disabling torch 2_0 attention (#3273)
* Allow disabling torch 2_0 attention
* make style
* Update src/diffusers/models/attention.py
1 parent a7b0671 commit 4d35d7f

File tree

1 file changed: +3 -3 lines changed

src/diffusers/models/attention.py

Lines changed: 3 additions & 3 deletions
@@ -71,6 +71,7 @@ def __init__(
         self.proj_attn = nn.Linear(channels, channels, bias=True)

         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None

     def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,9 +143,8 @@ def forward(self, hidden_states):

         scale = 1 / math.sqrt(self.channels / self.num_heads)

-        use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
-        )
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn

         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
         key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
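
With this change, the torch 2.0 `F.scaled_dot_product_attention` fast path in `AttentionBlock.forward` can be switched off per block by setting the new `_use_2_0_attn` attribute to False, which is useful for debugging or for comparing numerics against the manual attention path (the xformers flag already bypasses it). A minimal sketch of how that might be used; the constructor arguments and tensor shape below are illustrative assumptions, not part of this commit:

import torch
from diffusers.models.attention import AttentionBlock

# Small attention block; channel/head sizes here are only illustrative.
block = AttentionBlock(channels=64, num_head_channels=8)

# Force the pre-2.0 attention path: with _use_2_0_attn = False, forward()
# skips F.scaled_dot_product_attention even on PyTorch >= 2.0.
block._use_2_0_attn = False

hidden_states = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
out = block(hidden_states)
print(out.shape)  # same shape as the input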

0 commit comments
