diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 1085c452b076..8e537c6f3680 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -86,8 +86,10 @@ def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True): head_size = self.num_heads if unmerge_head_and_batch: - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + batch_head_size, seq_len, dim = tensor.shape + batch_size = batch_head_size // head_size + + tensor = tensor.reshape(batch_size, head_size, seq_len, dim) else: batch_size, _, seq_len, dim = tensor.shape