@@ -105,9 +105,8 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
-        is_lora = (
-            hasattr(self, "processor") and
-            isinstance(self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor))
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
         )
 
         if use_memory_efficient_attention_xformers:
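For context, this setter is normally reached indirectly: enabling xformers on a diffusers pipeline walks the model's modules and calls set_use_memory_efficient_attention_xformers on each attention block. A minimal, self-contained usage sketch (not part of this diff; the model id is only an example and xformers must be installed):

import torch
from diffusers import StableDiffusionPipeline

# Any Stable Diffusion checkpoint works; "runwayml/stable-diffusion-v1-5" is just an example.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# This helper propagates down to set_use_memory_efficient_attention_xformers(True)
# on every attention module of the UNet (and VAE), which is the method shown above.
pipe.enable_xformers_memory_efficient_attention()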
@@ -148,7 +147,8 @@ def set_use_memory_efficient_attention_xformers(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
                     rank=self.processor.rank,
-                    attention_op=attention_op)
+                    attention_op=attention_op,
+                )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
             else:
@@ -158,7 +158,8 @@ def set_use_memory_efficient_attention_xformers(
                 processor = LoRACrossAttnProcessor(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
-                    rank=self.processor.rank)
+                    rank=self.processor.rank,
+                )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
             else:
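The two LoRA branches above share one pattern: rebuild the processor as its xformers (or plain) counterpart, then carry the trained LoRA weights over and keep them on the original device. A standalone sketch of that pattern, assuming an older diffusers release that still exposes these classes under diffusers.models.cross_attention (the sizes below are hypothetical, not taken from the diff):

from diffusers.models.cross_attention import (
    LoRACrossAttnProcessor,
    LoRAXFormersCrossAttnProcessor,
)

# Hypothetical dimensions; in the method above they come from the existing processor.
old_processor = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)

# Rebuild as the xformers variant, then copy the LoRA layers
# (both classes use the same parameter names, so load_state_dict works directly).
new_processor = LoRAXFormersCrossAttnProcessor(
    hidden_size=320, cross_attention_dim=768, rank=4, attention_op=None
)
new_processor.load_state_dict(old_processor.state_dict())
new_processor.to(old_processor.to_q_lora.up.weight.device)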
@@ -493,7 +494,9 @@ def __call__(
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
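For reference, the call reformatted in the last hunk is the public xformers entry point; here it receives query/key/value already flattened by head_to_batch_dim to (batch * heads, seq_len, head_dim). A minimal sketch with made-up shapes, assuming xformers is installed and a CUDA device is available:

import torch
import xformers.ops

# Hypothetical sizes: 2 samples * 8 heads, 77 tokens, 64 dims per head.
query = torch.randn(16, 77, 64, device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# attn_bias=None means no attention mask; op=None lets xformers pick a suitable kernel.
hidden_states = xformers.ops.memory_efficient_attention(
    query, key, value, attn_bias=None, op=None
)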