Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
"UNet3DConditionModel",
"UNetMotionModel",
"VQModel",
'UNetKandi3'
]
)

Expand Down Expand Up @@ -290,6 +291,8 @@
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"KandinskyV3Pipeline",
"KandinskyV3Img2ImgPipeline",
]
)

Expand Down Expand Up @@ -459,6 +462,7 @@
UNet3DConditionModel,
UNetMotionModel,
VQModel,
UNetKandi3,
)
from .optimization import (
get_constant_schedule,
Expand Down Expand Up @@ -577,6 +581,8 @@
KandinskyV22Pipeline,
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
KandinskyV3Pipeline,
KandinskyV3Img2ImgPipeline,
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
LDMTextToImagePipeline,
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]
_import_structure["unet_kandi3"] = ["UNetKandi3"]


if is_flax_available():
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
Expand Down Expand Up @@ -65,13 +67,14 @@
from .unet_3d_condition import UNet3DConditionModel
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel

from .unet_kandi3 import UNetKandi3
if is_flax_available():
from .controlnet_flax import FlaxControlNetModel
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL

else:

import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
90 changes: 90 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRACompatibleLinear, LoRALinearLayer
from einops import rearrange, repeat
from torch import einsum


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

def exist(item):
return item is not None

if is_xformers_available():
import xformers
Expand Down Expand Up @@ -1974,12 +1978,96 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k
attn.processor = AttnAddedKVProcessor()
return attn.processor(attn, hidden_states, *args, **kwargs)

class Kandi3AttnProcessor:
r"""
Default kandinsky3 proccesor for performing attention-related computations.
"""

def __call__(
self,
attn,
x,
context,
context_mask=None,
image_mask=None

):
query = rearrange(attn.to_query(x), 'b n (h d) -> b h n d', h=attn.num_heads)
key = rearrange(attn.to_key(context), 'b n (h d) -> b h n d', h=attn.num_heads)
value = rearrange(attn.to_value(context), 'b n (h d) -> b h n d', h=attn.num_heads)

attention_matrix = einsum('b h i d, b h j d -> b h i j', query, key)
if exist(image_mask) and exist(context_mask):
image_mask = rearrange(image_mask, 'b i -> b 1 i 1')
image_text_mask_1 = rearrange((context_mask == 1).type(image_mask.dtype), 'b j -> b 1 1 j')
image_text_mask_2 = rearrange((context_mask == 2).type(image_mask.dtype), 'b j -> b 1 1 j')

image_mask_max = image_mask.amax(-2, keepdim=True)
max_attention = rearrange(attention_matrix.amax((-2, -1)), 'b h -> b h 1 1')
attention_matrix = attention_matrix + max_attention * (image_mask * image_text_mask_1)
attention_matrix = attention_matrix + max_attention * ((image_mask_max - image_mask) * image_text_mask_2)

if exist(context_mask):
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)

out = einsum('b h i j, b h j d -> b h i d', attention_matrix, value)
out = rearrange(out, 'b h n d -> b n (h d)')
out = attn.output_layer(out)
return out

class LoraKandi3AttnProcessor(nn.Module):

def __init__(self, in_channels, out_channels, context_dim, head_dim=64, rank=4, network_alpha=None):
super().__init__()

self.to_query_lora = LoRALinearLayer(in_channels, out_channels, rank, network_alpha)
self.to_key_lora = LoRALinearLayer(context_dim, out_channels, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(context_dim, out_channels, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_channels, out_channels, rank, network_alpha)

def forward(
self,
attn,
x,
context,
context_mask=None,
image_mask=None,
scale=1):
query = rearrange(attn.to_query(x) + self.to_query_lora(x), 'b n (h d) -> b h n d', h=attn.num_heads)
key = rearrange(attn.to_key(context) + self.to_key_lora(context), 'b n (h d) -> b h n d', h=attn.num_heads)
value = rearrange(attn.to_value(context) + self.to_v_lora(context), 'b n (h d) -> b h n d', h=attn.num_heads)

attention_matrix = einsum('b h i d, b h j d -> b h i j', query, key)
if exist(image_mask) and exist(context_mask):
image_mask = rearrange(image_mask, 'b i -> b 1 i 1')
image_text_mask_1 = rearrange((context_mask == 1).type(image_mask.dtype), 'b j -> b 1 1 j')
image_text_mask_2 = rearrange((context_mask == 2).type(image_mask.dtype), 'b j -> b 1 1 j')

image_mask_max = image_mask.amax(-2, keepdim=True)
max_attention = rearrange(attention_matrix.amax((-2, -1)), 'b h -> b h 1 1')
attention_matrix = attention_matrix + max_attention * (image_mask * image_text_mask_1)
attention_matrix = attention_matrix + max_attention * ((image_mask_max - image_mask) * image_text_mask_2)

if exist(context_mask):
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)

out = einsum('b h i j, b h j d -> b h i d', attention_matrix, value)
out = rearrange(out, 'b h n d -> b n (h d)')
out = attn.output_layer(out) + self.to_out_lora(out)
return out

LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
LoraKandi3AttnProcessor,
)

ADDED_KV_ATTENTION_PROCESSORS = (
Expand Down Expand Up @@ -2012,6 +2100,8 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
Kandi3AttnProcessor,
LoraKandi3AttnProcessor,
# deprecated
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
Expand Down
Loading