@@ -105,6 +105,11 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
+        is_lora = (
+            hasattr(self, "processor") and
+            isinstance(self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor))
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -138,10 +143,7 @@ def set_use_memory_efficient_attention_xformers(
                 except Exception as e:
                     raise e
 
-            if (
-                hasattr(self, "processor") and
-                isinstance(self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor))
-            ):
+            if is_lora:
                 processor = LoRAXFormersCrossAttnProcessor(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
@@ -152,7 +154,15 @@ def set_use_memory_efficient_attention_xformers(
             else:
                 processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
-            processor = CrossAttnProcessor()
+            if is_lora:
+                processor = LoRACrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank)
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = CrossAttnProcessor()
 
         self.set_processor(processor)
 
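For context, the practical effect of the new `is_lora` branch is that toggling xformers off no longer resets LoRA-equipped attention modules to a plain `CrossAttnProcessor`, which silently dropped the LoRA weights. Below is a minimal verification sketch, not part of this PR: it assumes a small randomly initialized `UNet2DConditionModel` (the config values are illustrative), the `LoRACrossAttnProcessor` import path from `diffusers.models.cross_attention`, and an arbitrary rank of 4.

```python
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Tiny, randomly initialized UNet so the sketch runs without downloading weights.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# Attach a LoRA processor to every attention module (rank 4 is arbitrary here).
lora_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention (attn1) has no encoder hidden states, so no cross-attention dim.
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )
unet.set_attn_processor(lora_procs)

# Toggling xformers off previously replaced every processor with a plain
# CrossAttnProcessor; with the `is_lora` branch the LoRA processors
# (and their state dict) survive the round trip.
unet.disable_xformers_memory_efficient_attention()
print({type(p).__name__ for p in unet.attn_processors.values()})
```

Before this change the printed set would be `{'CrossAttnProcessor'}` and the LoRA layers would be gone; with the `is_lora` branch it stays `{'LoRACrossAttnProcessor'}`, with the weights copied over via `load_state_dict` and moved to the original device.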