Commit c5933c9

[Bug fix] Fix batch size / attention head size mismatch (#3214)
1 parent 91a2a80 commit c5933c9

File tree

1 file changed: +4 −2 lines


src/diffusers/models/attention.py

Lines changed: 4 additions & 2 deletions
@@ -86,8 +86,10 @@ def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
 
         if unmerge_head_and_batch:
-            batch_size, seq_len, dim = tensor.shape
-            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
         else:
             batch_size, _, seq_len, dim = tensor.shape
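
For reference, a minimal standalone sketch of the fixed logic. This is a hypothetical free-standing helper, not the actual diffusers class method; it reproduces the unmerge branch from the diff above, plus the usual step of folding the heads back into the feature dimension, which the hunk does not show.

import torch

def reshape_batch_dim_to_heads(tensor: torch.Tensor, head_size: int) -> torch.Tensor:
    # Input arrives with heads merged into the batch dimension: (batch * heads, seq_len, dim).
    batch_head_size, seq_len, dim = tensor.shape
    batch_size = batch_head_size // head_size  # recover the true batch size (the fix)

    # Split the heads back out, then fold them into the feature dimension.
    tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
    return tensor

# Example: 2 samples, 8 heads, 16 tokens, 40 features per head.
merged = torch.randn(2 * 8, 16, 40)
print(reshape_batch_dim_to_heads(merged, head_size=8).shape)  # torch.Size([2, 16, 320])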
