1 file changed: +3 −3 lines changed

@@ -71,6 +71,7 @@ def __init__(
         self.proj_attn = nn.Linear(channels, channels, bias=True)

         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None

     def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,9 +143,8 @@ def forward(self, hidden_states):

         scale = 1 / math.sqrt(self.channels / self.num_heads)

-        use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
-        )
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn

         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
         key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
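
For context: the patch makes the PyTorch 2.0 attention path opt-out. `hasattr(F, "scaled_dot_product_attention")` still detects whether the fused kernel exists, but the new per-instance `_use_2_0_attn` flag (default `True`) can now force the manual path, and the xformers switch continues to take priority. Below is a minimal standalone sketch of what the gate controls, assuming plain multi-head attention over `(batch, seq_len, channels)` tensors; the `attention` function and `split_heads` helper are illustrative names, not the actual module from the patched file:

import math

import torch
import torch.nn.functional as F


def attention(query, key, value, num_heads, use_2_0_attn=True):
    """Toy gate between the fused 2.0 kernel and a manual fallback."""
    batch, seq_len, channels = query.shape
    head_dim = channels // num_heads

    # Same shape as the patched gate: the flag can veto the fused path
    # even when F.scaled_dot_product_attention is available.
    use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and use_2_0_attn

    def split_heads(t):
        # (batch, seq_len, channels) -> (batch, num_heads, seq_len, head_dim)
        return t.reshape(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

    q, k, v = split_heads(query), split_heads(key), split_heads(value)

    if use_torch_2_0_attn:
        # Fused kernel; its default scale is 1 / sqrt(head_dim).
        out = F.scaled_dot_product_attention(q, k, v)
    else:
        # Manual fallback with the 1 / sqrt(channels / num_heads) scale
        # from the diff (channels / num_heads == head_dim).
        scale = 1 / math.sqrt(head_dim)
        attn = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
        out = attn @ v

    # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, channels)
    return out.permute(0, 2, 1, 3).reshape(batch, seq_len, channels)


q = k = v = torch.randn(2, 16, 64)
fused = attention(q, k, v, num_heads=8, use_2_0_attn=True)
manual = attention(q, k, v, num_heads=8, use_2_0_attn=False)
assert torch.allclose(fused, manual, atol=1e-5)

Defaulting the flag to `True` preserves the previous behaviour; presumably the toggle exists so the fused kernel can be disabled for comparison or debugging without monkey-patching `F`.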