4 changes: 2 additions & 2 deletions src/diffusers/models/__init__.py
@@ -36,7 +36,7 @@
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]

@@ -64,7 +64,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
-from .unet_kandi3 import Kandinsky3UNet
+from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel
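
The rename above only moves the private module; the public import path is unchanged. A minimal sketch of what that means for imports, assuming the PR also renames unet_kandi3.py to unet_kandinsky3.py (not shown in this excerpt):

# Both imports should refer to the same class after the rename.
from diffusers.models import Kandinsky3UNet                         # public path, unchanged
from diffusers.models.unet_kandinsky3 import Kandinsky3UNet as K3   # new private module path

assert K3 is Kandinsky3UNet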

49 changes: 6 additions & 43 deletions src/diffusers/models/attention_processor.py
@@ -16,7 +16,7 @@

import torch
import torch.nn.functional as F
-from torch import einsum, nn
+from torch import nn

from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
Expand Down Expand Up @@ -109,15 +109,17 @@ def __init__(
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
+out_dim: int = None,
Contributor: Is out_dim different from query_dim here?

Collaborator Author @yiyixuxu (Nov 29, 2023): @patrickvonplaten The only difference is the to_out layer: Kandinsky attention output does not change the dimension from inner_dim, while our Attention class projects the output to query_dim. I added an out_dim for this purpose, but we can add a different config if it makes more sense!

self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))

Contributor: That works! Makes sense.
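
For reference, a minimal sketch of how the new out_dim argument flows through the constructor changes in this diff; the sizes below are made up, only the shape logic mirrors the code:

import torch
from torch import nn

# Hypothetical sizes chosen to exercise the out_dim branches added in this diff.
query_dim, dim_head, heads, out_dim = 320, 64, 8, 768

inner_dim = out_dim if out_dim is not None else dim_head * heads   # 768 instead of 8 * 64 = 512
num_heads = out_dim // dim_head if out_dim is not None else heads  # 12 instead of 8
proj_dim = out_dim if out_dim is not None else query_dim           # 768: no projection back to query_dim

to_out = nn.Linear(inner_dim, proj_dim)   # Kandinsky-style: the output keeps inner_dim
x = torch.randn(2, 77, inner_dim)
print(to_out(x).shape)                    # torch.Size([2, 77, 768])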

):
super().__init__()
-self.inner_dim = dim_head * heads
+self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
+self.out_dim = out_dim if out_dim is not None else query_dim

# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +128,7 @@ def __init__(
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0

-self.heads = heads
+self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
@@ -193,7 +195,7 @@ def __init__(
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)

self.to_out = nn.ModuleList([])
-self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))

# set attention processor
@@ -2219,44 +2221,6 @@ def __call__(
return hidden_states


-# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
Contributor: Nice!

-# this way torch.compile and co. will work as well
-class Kandi3AttnProcessor:
-    r"""
-    Default kandinsky3 proccesor for performing attention-related computations.
-    """
-
-    @staticmethod
-    def _reshape(hid_states, h):
-        b, n, f = hid_states.shape
-        d = f // h
-        return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
-
-    def __call__(
-        self,
-        attn,
-        x,
-        context,
-        context_mask=None,
-    ):
-        query = self._reshape(attn.to_q(x), h=attn.num_heads)
-        key = self._reshape(attn.to_k(context), h=attn.num_heads)
-        value = self._reshape(attn.to_v(context), h=attn.num_heads)
-
-        attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
-
-        if context_mask is not None:
-            max_neg_value = -torch.finfo(attention_matrix.dtype).max
-            context_mask = context_mask.unsqueeze(1).unsqueeze(1)
-            attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
-        attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
-
-        out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
-        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
-        out = attn.to_out[0](out)
-        return out
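
For context on the TODO above: the removed processor is plain scaled-dot-product attention written with einsum, which is why it can be folded into the standard processors (and then picked up by torch.compile and F.scaled_dot_product_attention). A rough sketch of the equivalence, with made-up shapes and no attention mask:

import torch
import torch.nn.functional as F

# Hypothetical shapes: batch, heads, query length, key length, head dim.
b, h, i, j, d = 2, 8, 16, 24, 64
query = torch.randn(b, h, i, d)
key = torch.randn(b, h, j, d)
value = torch.randn(b, h, j, d)
scale = d ** -0.5  # matches attn.scale when scale_qk=True

# einsum path, as in the removed Kandi3AttnProcessor (mask omitted)
scores = torch.einsum("b h i d, b h j d -> b h i j", query, key)
probs = (scores * scale).softmax(dim=-1)
out_einsum = torch.einsum("b h i j, b h j d -> b h i d", probs, value)

# fused path used by the standard 2.0 processors
out_sdpa = F.scaled_dot_product_attention(query, key, value)

print(torch.allclose(out_einsum, out_sdpa, atol=1e-5))  # expected True, up to float tolerance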


LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2282,7 +2246,6 @@ def __call__(
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
-Kandi3AttnProcessor,
)

AttentionProcessor = Union[