Skip to content

Commit cf9957e

Browse files
committed
Fix CrossAttention._sliced_attention
1 parent 0424615 commit cf9957e

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/diffusers/models/attention.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -249,13 +249,15 @@ def reshape_batch_dim_to_heads(self, tensor):
249249
return tensor
250250

251251
def forward(self, hidden_states, context=None, mask=None):
252-
batch_size, sequence_length, dim = hidden_states.shape
252+
batch_size, sequence_length, _ = hidden_states.shape
253253

254254
query = self.to_q(hidden_states)
255255
context = context if context is not None else hidden_states
256256
key = self.to_k(context)
257257
value = self.to_v(context)
258258

259+
dim = query.shape[-1]
260+
259261
query = self.reshape_heads_to_batch_dim(query)
260262
key = self.reshape_heads_to_batch_dim(key)
261263
value = self.reshape_heads_to_batch_dim(value)
@@ -283,7 +285,7 @@ def _attention(self, query, key, value):
283285
def _sliced_attention(self, query, key, value, sequence_length, dim):
284286
batch_size_attention = query.shape[0]
285287
hidden_states = torch.zeros(
286-
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
288+
(batch_size_attention, sequence_length, dim), device=query.device, dtype=query.dtype
287289
)
288290
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
289291
for i in range(hidden_states.shape[0] // slice_size):

0 commit comments

Comments (0)