diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 2ea2e7be58e8..baccdd83f202 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -105,6 +105,10 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -138,9 +142,28 @@ def set_use_memory_efficient_attention_xformers(
                 except Exception as e:
                     raise e
 
-            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
+            if is_lora:
+                processor = LoRAXFormersCrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
-            processor = CrossAttnProcessor()
+            if is_lora:
+                processor = LoRACrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = CrossAttnProcessor()
 
         self.set_processor(processor)
 
@@ -324,6 +347,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -437,9 +464,14 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
 
 class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()
 
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        self.attention_op = attention_op
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -462,7 +494,9 @@ def __call__(
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
@@ -595,4 +629,6 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
     SlicedAttnProcessor,
     CrossAttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    LoRACrossAttnProcessor,
+    LoRAXFormersCrossAttnProcessor,
 ]
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 70d9d6a59881..6ee8c2ffc002 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -412,6 +412,35 @@ def test_lora_on_off(self):
         assert (sample - new_sample).abs().max() < 1e-4
         assert (sample - old_sample).abs().max() < 1e-4
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_lora_xformers_on_off(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        # default
+        with torch.no_grad():
+            sample = model(**inputs_dict).sample
+
+            model.enable_xformers_memory_efficient_attention()
+            on_sample = model(**inputs_dict).sample
+
+            model.disable_xformers_memory_efficient_attention()
+            off_sample = model(**inputs_dict).sample
+
+        assert (sample - on_sample).abs().max() < 1e-4
+        assert (sample - off_sample).abs().max() < 1e-4
+
 
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):
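
Usage note (not part of the diff): a minimal sketch of the code path the new test exercises, namely attaching LoRA attention processors to a UNet and then toggling xformers. With this change, enabling xformers rebuilds each LoRACrossAttnProcessor as a LoRAXFormersCrossAttnProcessor and copies its weights, instead of replacing it with a plain XFormersCrossAttnProcessor. The checkpoint id, the rank, and the processor-construction loop below are illustrative assumptions, not taken from this PR.

import torch

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Illustrative checkpoint; any UNet2DConditionModel works the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Build one LoRA processor per attention module (self-attention layers have no
# cross-attention dim), mirroring the usual diffusers LoRA setup loop.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_attn_procs)
unet.to("cuda")

# With this diff, enabling xformers swaps each LoRA processor for a
# LoRAXFormersCrossAttnProcessor (loading its state dict and keeping its device);
# disabling swaps it back to a LoRACrossAttnProcessor.
unet.enable_xformers_memory_efficient_attention()
# ... run inference or training ...
unet.disable_xformers_memory_efficient_attention()

After enable_xformers_memory_efficient_attention(), the swapped-in processors carry the same LoRA weights, which is what the new test_lora_xformers_on_off asserts: outputs with xformers on and off match the default output within 1e-4.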