diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index e814981a85c9..acb80882028b 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -213,6 +213,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         attn_processors = {}
 
         is_lora = all("lora" in k for k in state_dict.keys())
+        is_custom_diffusion = all("custom_diffusion" in k for k in state_dict.keys())
 
         if is_lora:
             lora_grouped_dict = defaultdict(dict)
@@ -230,6 +231,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 )
                 attn_processors[key].load_state_dict(value_dict)
 
+        elif is_custom_diffusion:
+            # Needs to be implemented.
+            pass
         else:
             raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
 
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index f2a5a376bf39..67f1e6e6c032 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -824,6 +824,68 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         return hidden_states
 
 
+class CustomDiffusionAttnProcessor(nn.Module):
+    def __init__(self, train_kv=True, train_q=True, hidden_dim=None, cross_attention_dim=None):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q = train_q
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_dim, hidden_dim)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_dim, hidden_dim)
+        if self.train_q:
+            self.to_q_custom_diffusion = nn.Linear(hidden_dim, hidden_dim)
+
+    def forward(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+    ):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        # Detach the key/value of the first (start-of-text) token so it receives no gradient.
+        detach = torch.ones_like(key)
+        detach[:, :1, :] = detach[:, :1, :] * 0.0
+        key = detach * key + (1 - detach) * key.detach()
+        value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 AttentionProcessor = Union[
     AttnProcessor,
     AttnProcessor2_0,
@@ -834,4 +896,5 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor,
 ]
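
Usage sketch (not part of the patch): one plausible way to wire the new CustomDiffusionAttnProcessor into a UNet2DConditionModel via set_attn_processor, resolving the per-block hidden size the same way the existing LoRA training scripts do. The checkpoint id, the attn2-only filter, and the train_kv/train_q choices below are illustrative assumptions, not something this PR defines.

# Minimal sketch, assuming the constructor signature added above and the
# block-naming scheme exposed by UNet2DConditionModel.attn_processors.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # any SD checkpoint works
)

custom_diffusion_attn_procs = {}
for name, processor in unet.attn_processors.items():
    if name.endswith("attn2.processor"):  # cross-attention layers only
        # Resolve the attention hidden size from the enclosing block,
        # mirroring the lookup used in the LoRA training scripts.
        if name.startswith("mid_block"):
            hidden_dim = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_dim = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_dim = unet.config.block_out_channels[block_id]
        custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor(
            train_kv=True,
            train_q=False,
            hidden_dim=hidden_dim,
            cross_attention_dim=unet.config.cross_attention_dim,
        )
    else:
        # Keep self-attention layers on their existing processor.
        custom_diffusion_attn_procs[name] = processor
unet.set_attn_processor(custom_diffusion_attn_procs)

# Because the processors are nn.Modules attached to the attention blocks, their
# parameter names carry the "custom_diffusion" marker that load_attn_procs keys on.
trainable_params = [p for n, p in unet.named_parameters() if "custom_diffusion" in n]

In practice one would likely initialize the new to_k_custom_diffusion/to_v_custom_diffusion projections from the frozen to_k/to_v weights before training rather than leaving them randomly initialized; that step is omitted here for brevity.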