Skip to content

[DON'T MERGE] PoC Custom Diffusion #3088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
attn_processors = {}

is_lora = all("lora" in k for k in state_dict.keys())
is_custom_diffusion = all("custom_diffusion" in k for k in state_dict.keys())

if is_lora:
lora_grouped_dict = defaultdict(dict)
Expand All @@ -230,6 +231,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
)
attn_processors[key].load_state_dict(value_dict)

elif is_custom_diffusion:
# Needs to be implemented.
pass
else:
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")

Expand Down
62 changes: 62 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,67 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
return hidden_states


class CustomDiffusionAttnProcessor(nn.Module):
def __init__(self, train_kv=True, train_q=True, hidden_dim=None, cross_attention_dim=None):
super().__init__()
self.train_kv = train_kv
self.train_q = train_q

# `_custom_diffusion` id for easy serialization and loading.
if self.train_kv:
self.to_k_custom_diffusion = nn.Linear(hidden_dim, cross_attention_dim)
self.to_v_custom_diffusion = nn.Linear(hidden_dim, cross_attention_dim)
if self.train_q:
self.to_q_custom_diffusion = nn.Linear(hidden_dim, cross_attention_dim)

def forward(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q:
query = self.to_q_custom_diffusion(hidden_states)
else:
query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
if attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0
key = detach * key + (1 - detach) * key.detach()
value = detach * value + (1 - detach) * value.detach()

if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states)
value = self.to_v_custom_diffusion(encoder_hidden_states)
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

return hidden_states


AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
Expand All @@ -834,4 +895,5 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
AttnAddedKVProcessor2_0,
LoRAAttnProcessor,
LoRAXFormersAttnProcessor,
CustomDiffusionAttnProcessor,
]