From 437fe0ec3a83af0eb4b08da48f0ce72145079c55 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Mon, 19 Sep 2022 15:27:27 +0200
Subject: [PATCH] Fix CrossAttention._sliced_attention

---
 src/diffusers/models/attention.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e4cedbff8c9a..25e1ea28dcf0 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -249,13 +249,15 @@ def reshape_batch_dim_to_heads(self, tensor):
         return tensor
 
     def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, dim = hidden_states.shape
+        batch_size, sequence_length, _ = hidden_states.shape
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
         key = self.to_k(context)
         value = self.to_v(context)
 
+        dim = query.shape[-1]
+
         query = self.reshape_heads_to_batch_dim(query)
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
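
Note on why the fix matters: `to_q` projects `hidden_states` from `query_dim` channels to
`inner_dim = heads * dim_head` channels, so the third element of `hidden_states.shape` is
the wrong width whenever `heads * dim_head != query_dim`. At this commit,
`_sliced_attention` uses the `dim` argument to allocate its per-head output buffer
(`dim // self.heads` channels), so it must receive the projected width, i.e.
`query.shape[-1]`. A minimal standalone sketch of the shape mismatch, not the diffusers
class itself; `query_dim`, `heads`, and `dim_head` mirror the CrossAttention constructor
arguments, the concrete numbers and variable names are illustrative:

    import torch
    from torch import nn

    # Numbers chosen so the projected width differs from the input width,
    # which is exactly the case the patch fixes.
    query_dim, heads, dim_head = 320, 8, 64
    inner_dim = heads * dim_head              # 512 != query_dim

    to_q = nn.Linear(query_dim, inner_dim, bias=False)

    hidden_states = torch.randn(2, 77, query_dim)
    batch_size, sequence_length, _ = hidden_states.shape

    query = to_q(hidden_states)

    # Pre-patch, `dim` was unpacked from hidden_states.shape (320) and passed
    # to _sliced_attention, which sizes its output as (..., dim // heads).
    # Post-patch, `dim` is read off the projection instead:
    dim = query.shape[-1]
    assert dim == inner_dim                   # 512: the width the attention output really has
    assert dim != hidden_states.shape[-1]     # the old value would have been 320

With `dim = 320` the sliced path would allocate 40-channel per-head buffers for 64-channel
attention outputs; reading `dim` from the query projection keeps the two in agreement.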