@@ -105,6 +105,10 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -138,9 +142,28 @@ def set_use_memory_efficient_attention_xformers(
                 except Exception as e:
                     raise e

-            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
+            if is_lora:
+                processor = LoRAXFormersCrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
-            processor = CrossAttnProcessor()
+            if is_lora:
+                processor = LoRACrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = CrossAttnProcessor()

         self.set_processor(processor)
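The effect of this branch is that switching attention backends no longer silently drops LoRA layers: when the current processor is a LoRA variant, the replacement is rebuilt with the same `hidden_size` / `cross_attention_dim` / `rank`, and the trained weights are copied over via `load_state_dict`. A minimal sketch of the non-xformers round trip (dimensions are illustrative, and the import path assumes a diffusers version where these classes live in `diffusers.models.cross_attention`):

```python
import torch
from diffusers.models.cross_attention import (
    CrossAttention,
    LoRACrossAttnProcessor,
    LoRAXFormersCrossAttnProcessor,
)

# Illustrative dimensions: heads * dim_head matches hidden_size here.
attn = CrossAttention(query_dim=64, heads=4, dim_head=16, cross_attention_dim=96)
attn.set_processor(
    LoRAXFormersCrossAttnProcessor(hidden_size=64, cross_attention_dim=96, rank=4)
)

before = attn.processor.to_q_lora.down.weight.clone()

# Turning xformers off now rebuilds a LoRACrossAttnProcessor from the stored
# hidden_size / cross_attention_dim / rank and reloads the LoRA state dict.
attn.set_use_memory_efficient_attention_xformers(False)

assert isinstance(attn.processor, LoRACrossAttnProcessor)
assert torch.equal(attn.processor.to_q_lora.down.weight, before)
```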
@@ -324,6 +347,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -437,9 +464,14 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No


 class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        self.attention_op = attention_op
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -462,7 +494,9 @@ def __call__(
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()

-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
@@ -595,4 +629,6 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
     SlicedAttnProcessor,
     CrossAttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    LoRACrossAttnProcessor,
+    LoRAXFormersCrossAttnProcessor,
 ]
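With the two LoRA processors registered in the `AttentionProcessor` union and `attention_op` now forwarded to `xformers.ops.memory_efficient_attention`, enabling xformers on a LoRA-equipped model keeps the LoRA layers active instead of replacing them with plain processors. A hedged end-to-end sketch (the model id and LoRA weight path are placeholders; requires a CUDA device and xformers installed):

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder model id and LoRA weight path, shown for illustration only.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.unet.load_attn_procs("path/to/lora_weights")  # attaches LoRACrossAttnProcessor instances

# With this change, the call below swaps in LoRAXFormersCrossAttnProcessor
# instances that keep the loaded LoRA weights; a custom attention_op, if passed,
# is stored on the processor and forwarded to memory_efficient_attention.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse").images[0]
```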