From dc11abf1f3116932507abefbe5347f313f3ac634 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 24 Apr 2023 22:43:41 +0200
Subject: [PATCH] [Bug fix] Fix batch size attention head size mismatch

---
 src/diffusers/models/attention.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 1085c452b076..8e537c6f3680 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -86,8 +86,10 @@ def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
 
         if unmerge_head_and_batch:
-            batch_size, seq_len, dim = tensor.shape
-            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
         else:
             batch_size, _, seq_len, dim = tensor.shape
 
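
--
A minimal standalone sketch of the merge/unmerge round-trip that
reshape_batch_dim_to_heads performs, assuming PyTorch; the sizes and
variable names below are illustrative and not part of the patch. It
shows why the leading dimension of the incoming tensor is batch * heads,
so the true batch size must be recovered by integer division by
head_size, which is exactly what the fixed code path does.

    import torch

    # Illustrative sizes, not taken from the patch.
    batch_size, head_size, seq_len, dim = 2, 8, 16, 64

    # Merge heads into the batch dimension, as attention implementations
    # do before computing per-head attention scores.
    x = torch.randn(batch_size, head_size, seq_len, dim)
    merged = x.reshape(batch_size * head_size, seq_len, dim)

    # Unmerge: the leading dimension of `merged` is batch * heads, so the
    # true batch size is batch_head_size // head_size, matching the
    # renamed variables in the fixed code.
    batch_head_size, seq_len, dim = merged.shape
    recovered = merged.reshape(batch_head_size // head_size, head_size, seq_len, dim)

    # The round-trip is lossless.
    assert torch.equal(x, recovered)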