@@ -16,7 +16,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch import einsum, nn
+from torch import nn
 
 from ..utils import USE_PEFT_BACKEND, deprecate, logging
 from ..utils.import_utils import is_xformers_available
@@ -109,15 +109,17 @@ def __init__(
         residual_connection: bool = False,
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
     ):
         super().__init__()
-        self.inner_dim = dim_head * heads
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
         self.dropout = dropout
+        self.out_dim = out_dim if out_dim is not None else query_dim
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +128,7 @@ def __init__(
         self.scale_qk = scale_qk
         self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
-        self.heads = heads
+        self.heads = out_dim // dim_head if out_dim is not None else heads
         # for slice_size > 0 the attention score computation
         # is split across the batch axis to save memory
         # You can set slice_size with `set_attention_slice`
@@ -193,7 +195,7 @@ def __init__(
             self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
@@ -2219,44 +2221,6 @@ def __call__(
         return hidden_states
 
 
-# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
-# this way torch.compile and co. will work as well
-class Kandi3AttnProcessor:
-    r"""
-    Default kandinsky3 proccesor for performing attention-related computations.
-    """
-
-    @staticmethod
-    def _reshape(hid_states, h):
-        b, n, f = hid_states.shape
-        d = f // h
-        return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
-
-    def __call__(
-        self,
-        attn,
-        x,
-        context,
-        context_mask=None,
-    ):
-        query = self._reshape(attn.to_q(x), h=attn.num_heads)
-        key = self._reshape(attn.to_k(context), h=attn.num_heads)
-        value = self._reshape(attn.to_v(context), h=attn.num_heads)
-
-        attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
-
-        if context_mask is not None:
-            max_neg_value = -torch.finfo(attention_matrix.dtype).max
-            context_mask = context_mask.unsqueeze(1).unsqueeze(1)
-            attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
-        attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
-
-        out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
-        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
-        out = attn.to_out[0](out)
-        return out
-
-
 LORA_ATTENTION_PROCESSORS = (
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
@@ -2282,7 +2246,6 @@ def __call__(
     LoRAXFormersAttnProcessor,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
-    Kandi3AttnProcessor,
 )
 
 AttentionProcessor = Union[
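For reference, a minimal sketch of what the new `out_dim` argument changes in `Attention`: when it is set, `inner_dim` becomes `out_dim`, the head count is re-derived as `out_dim // dim_head`, and `to_out[0]` projects to `out_dim` instead of `query_dim`. This assumes a diffusers build that already contains this change; the import path is the usual module location and the constructor values below are arbitrary illustration, not taken from the commit.

# Illustrative sketch of the out_dim behaviour introduced above; values are made up.
from diffusers.models.attention_processor import Attention

# Default behaviour: inner_dim = dim_head * heads, output projected back to query_dim.
attn = Attention(query_dim=320, heads=8, dim_head=40)
assert attn.inner_dim == 320 and attn.heads == 8

# With out_dim set: inner_dim = out_dim, heads = out_dim // dim_head,
# and to_out[0] maps inner_dim -> out_dim rather than query_dim.
attn = Attention(query_dim=320, heads=8, dim_head=40, out_dim=640)
assert attn.inner_dim == 640 and attn.heads == 16 and attn.out_dim == 640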