Conversation

@ydshieh (Contributor) commented Sep 19, 2022

Fix `CrossAttention._sliced_attention`; see the review comments below.

To reproduce without this PR:

from diffusers.models.attention import CrossAttention
import numpy as np
import torch

# # This is OK: i.e. if `query_dim` == `inner_dim`
# N, T_query, T_context, query_dim, context_dim, heads, dim_head = (2, 5, 3, 16, 4, 2, 8)

# This fails
N, T_query, T_context, query_dim, context_dim, heads, dim_head = (2, 5, 3, 8, 4, 2, 8)

sample = np.random.default_rng().standard_normal(size=(N, T_query, query_dim), dtype=np.float32)
context = np.random.default_rng().standard_normal(size=(N, T_context, context_dim), dtype=np.float32)
pt_sample = torch.tensor(sample, dtype=torch.float32)
pt_context = torch.tensor(context, dtype=torch.float32)


pt_layer = CrossAttention(query_dim=query_dim, context_dim=context_dim, heads=heads, dim_head=dim_head)
# Use sliced attention
pt_layer._slice_size = 1

with torch.no_grad():
    pt_output = pt_layer(pt_sample, context=pt_context)

which gives:

Traceback (most recent call last):
  File "C:\Users\33611\Desktop\Project\diffusers\debug.py", line 18, in <module>
    pt_output = pt_layer(pt_sample, context=pt_context)
  File "C:\Users\33611\miniconda3\envs\py39\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\33611\Desktop\Project\diffusers\src\diffusers\models\attention.py", line 270, in forward
    hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
  File "C:\Users\33611\Desktop\Project\diffusers\src\diffusers\models\attention.py", line 296, in _sliced_attention
    hidden_states[start_idx:end_idx] = attn_slice
RuntimeError: The expanded size of the tensor (4) must match the existing size (8) at non-singleton dimension 2.  Target sizes: [1, 5, 4].  Tensor sizes: [5, 8]
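
For context, the mismatch comes from the output buffer that `_sliced_attention` pre-allocates from `dim`. Below is a minimal sketch of the failing shapes (illustrative only, not the actual diffusers source; `out` and the literal values follow the repro above):

```python
import torch

# Illustrative sketch only -- not the actual diffusers code.
N, T_query = 2, 5
heads, dim_head = 2, 8
query_dim = 8                     # hidden_states.shape[-1] in the repro
inner_dim = heads * dim_head      # 16, width of the projected query

# Bug: `dim` was derived from hidden_states, i.e. dim == query_dim == 8,
# so the per-head output buffer gets width dim // heads == 4 ...
dim = query_dim
out = torch.zeros(N * heads, T_query, dim // heads)   # shape (4, 5, 4)

# ... but each computed attention slice has width dim_head == 8:
attn_slice = torch.zeros(1, T_query, dim_head)        # shape (1, 5, 8)

out[0:1] = attn_slice   # raises RuntimeError: size 4 vs 8 at dimension 2
```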

@HuggingFaceDocBuilderDev commented Sep 19, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh force-pushed the fix_sliced_attention branch from 2c4047f to cf9957e on September 19, 2022 12:33
@ydshieh ydshieh force-pushed the fix_sliced_attention branch from cf9957e to 437fe0e on September 19, 2022 13:28
key = self.to_k(context)
value = self.to_v(context)

dim = query.shape[-1]
@ydshieh (Contributor, Author) commented on `dim = query.shape[-1]` above:

The `dim` should be computed from the projected q, k, v (i.e. `inner_dim`), not from `hidden_states`.

@ydshieh (Contributor, Author) added:

The current CI passes only because the existing test happens to have `inner_dim == hidden_states.shape[-1]` (both 64). When that is not the case, we get the error above.
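
To illustrate the fix these comments describe (a simplified sketch, not the actual diff; `out` and the values again follow the repro above): taking `dim` from the projected query makes the buffer width `dim // heads` equal `dim_head`, so the slice assignment fits.

```python
import torch

# Sketch of the fixed shapes, per the comments above (not the library code).
N, T_query = 2, 5
heads, dim_head = 2, 8
inner_dim = heads * dim_head        # 16 == query.shape[-1] after self.to_q

dim = inner_dim                     # instead of hidden_states.shape[-1] == 8
out = torch.zeros(N * heads, T_query, dim // heads)   # shape (4, 5, 8)
attn_slice = torch.zeros(1, T_query, dim_head)        # shape (1, 5, 8)
out[0:1] = attn_slice               # widths match; assignment succeeds
```

With that change, the repro script above should run, and `pt_output` should come out with shape `(N, T_query, query_dim)`, since `to_out` projects from `inner_dim` back to `query_dim`.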

@patrickvonplaten patrickvonplaten marked this pull request as ready for review September 19, 2022 13:43
@patrickvonplaten (Contributor) left a comment:

Very nice! Thanks a lot @ydshieh

@ydshieh ydshieh merged commit 84616b5 into main Sep 19, 2022
@ydshieh ydshieh deleted the fix_sliced_attention branch September 19, 2022 16:07
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Fix CrossAttention._sliced_attention

Co-authored-by: ydshieh <[email protected]>