[Kandinsky 3.0] Follow-up TODOs #5944
Merged
Commits (24, all authored by yiyixuxu):
- `293c480` remove kandinsky specific attention and attention processor
- `e5a1f32` Merge branch 'main' into kand-3
- `eca5c18` remove set_default_layer and set_default_item
- `b8bb288` remove SinusoidalPosEmb
- `84ce3d6` more
- `4b5fe93` Merge branch 'main' into kand-3
- `123dafc` disable batch test
- `145ddad` fix cpu model offload
- `5a2dd24` take off last to-do
- `3318034` Merge branch 'main' into kand-3
- `89fdee4` another to-do
- `d9406cb` Merge branch 'kand-3' of github.com:huggingface/diffusers into kand-3
- `e408cdf` refactor
- `7dced94` add callback and latent output for text2img
- `51fe17b` refactor img2img
- `f5cfa5a` change unet file name
- `334cd2e` add doc string
- `dbf5135` change pipeline file name
- `25c4e07` fix failing test
- `dd198cb` rename unet file
- `d60bc4e` testing prints
- `1d170cc` style
- `c4eae7e` allow pass prompt_embeds
- `226f755` offload
Diff view
```diff
@@ -16,7 +16,7 @@
 import torch
 import torch.nn.functional as F
-from torch import einsum, nn
+from torch import nn
 
 from ..utils import USE_PEFT_BACKEND, deprecate, logging
 from ..utils.import_utils import is_xformers_available
@@ -109,15 +109,17 @@ def __init__(
         residual_connection: bool = False,
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
     ):
         super().__init__()
-        self.inner_dim = dim_head * heads
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
         self.dropout = dropout
+        self.out_dim = out_dim if out_dim is not None else query_dim
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +128,7 @@ def __init__(
         self.scale_qk = scale_qk
         self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
-        self.heads = heads
+        self.heads = out_dim // dim_head if out_dim is not None else heads
         # for slice_size > 0 the attention score computation
         # is split across the batch axis to save memory
         # You can set slice_size with `set_attention_slice`
@@ -193,7 +195,7 @@ def __init__(
             self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
```
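A minimal sketch of what the new `out_dim` argument changes, assuming the `Attention` class from `diffusers.models.attention_processor` at this revision (the numeric sizes below are made up for illustration): when `out_dim` is set, `inner_dim` follows it, the head count is derived as `out_dim // dim_head`, and the output projection maps to `out_dim` rather than back to `query_dim`.

```python
from diffusers.models.attention_processor import Attention

# Default behaviour: inner_dim = heads * dim_head, and to_out projects back to query_dim.
attn_default = Attention(query_dim=320, heads=8, dim_head=64)
assert attn_default.inner_dim == 8 * 64                # 512
assert attn_default.to_out[0].out_features == 320      # projected back to query_dim

# With out_dim (the Kandinsky 3 case): inner_dim == out_dim, heads = out_dim // dim_head,
# and to_out keeps the inner dimension instead of projecting to query_dim.
attn_out_dim = Attention(query_dim=320, dim_head=64, out_dim=512)
assert attn_out_dim.inner_dim == 512
assert attn_out_dim.heads == 512 // 64                 # 8
assert attn_out_dim.to_out[0].out_features == 512
```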
```diff
@@ -2219,44 +2221,6 @@ def __call__(
         return hidden_states
 
 
-# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
-# this way torch.compile and co. will work as well
-class Kandi3AttnProcessor:
-    r"""
-    Default kandinsky3 proccesor for performing attention-related computations.
-    """
-
-    @staticmethod
-    def _reshape(hid_states, h):
-        b, n, f = hid_states.shape
-        d = f // h
-        return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
-
-    def __call__(
-        self,
-        attn,
-        x,
-        context,
-        context_mask=None,
-    ):
-        query = self._reshape(attn.to_q(x), h=attn.num_heads)
-        key = self._reshape(attn.to_k(context), h=attn.num_heads)
-        value = self._reshape(attn.to_v(context), h=attn.num_heads)
-
-        attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
-
-        if context_mask is not None:
-            max_neg_value = -torch.finfo(attention_matrix.dtype).max
-            context_mask = context_mask.unsqueeze(1).unsqueeze(1)
-            attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
-        attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
-
-        out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
-        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
-        out = attn.to_out[0](out)
-        return out
-
-
 LORA_ATTENTION_PROCESSORS = (
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
@@ -2282,7 +2246,6 @@ def __call__(
     LoRAXFormersAttnProcessor,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
-    Kandi3AttnProcessor,
 )
 
 AttentionProcessor = Union[
```

Review comment (Contributor), on the removed `# TODO(Yiyi)` line: Nice!
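For context on why the removed `Kandi3AttnProcessor` is safe to drop: it computed plain scaled dot-product attention via `einsum`, which the standard processors already express through PyTorch's fused kernel (and which `torch.compile` handles better, as the removed TODO notes). A minimal sketch of that equivalence, with made-up shapes and no attention mask:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, i, j, d = 2, 8, 16, 24, 64   # batch, heads, query tokens, context tokens, head dim
query = torch.randn(b, h, i, d)
key = torch.randn(b, h, j, d)
value = torch.randn(b, h, j, d)

# The einsum formulation used by the removed Kandi3AttnProcessor.
scores = torch.einsum("b h i d, b h j d -> b h i j", query, key)
weights = (scores * d**-0.5).softmax(dim=-1)
out_einsum = torch.einsum("b h i j, b h j d -> b h i d", weights, value)

# The fused PyTorch kernel used by the default 2.0 processor; its default scale is also 1/sqrt(d).
out_sdpa = F.scaled_dot_product_attention(query, key, value)

assert torch.allclose(out_einsum, out_sdpa, atol=1e-5, rtol=1e-4)
```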
Conversation

patrickvonplaten: Is `out_dim` different from `query_dim` here?

yiyixuxu: @patrickvonplaten The only difference is the `to_out` layer here: Kandinsky attention output does not change the dimension from `inner_dim`, while our attention class will project the output to `query_dim`. I added an `out_dim` for this purpose, but we can add a different config if it makes more sense! (References `diffusers/src/diffusers/models/unet_kandi3.py`, line 453 at `d1b2a1a`.)

patrickvonplaten: That works! Makes sense.
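A hedged illustration of the point discussed above, assuming the `Attention` class with the new `out_dim` argument and made-up sizes: with `out_dim` set, the forward output stays at the inner dimension (Kandinsky's behaviour) instead of being projected back to `query_dim`.

```python
import torch
from diffusers.models.attention_processor import Attention

hidden_states = torch.randn(1, 16, 320)  # (batch, tokens, query_dim)

# Standard configuration: output is projected back to query_dim.
attn = Attention(query_dim=320, heads=8, dim_head=64)
assert attn(hidden_states).shape == (1, 16, 320)

# Kandinsky-style configuration: output keeps the inner dimension (out_dim).
attn_kandinsky = Attention(query_dim=320, dim_head=64, out_dim=512)
assert attn_kandinsky(hidden_states).shape == (1, 16, 512)
```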