
Commit 89eb794

For the (unused) attention bias which the baddbmm API requires us to pass in: create a smaller bias and broadcast it. This helps prevent NaN results from baddbmm() on MPS on PyTorch 2.1.0.dev20230310.
1 parent bbab855 commit 89eb794


src/diffusers/models/cross_attention.py

Lines changed: 9 additions & 5 deletions
@@ -231,16 +231,20 @@ def get_attention_scores(self, query, key, attention_mask=None):
             key = key.float()
 
         if attention_mask is None:
-            baddbmm_input = torch.empty(
-                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
-            )
+            batch_x_heads, query_tokens, _ = query.shape
+            _, key_tokens, _ = key.shape
+            # expanding dims isn't strictly necessary (baddbmm supports broadcasting bias),
+            # but documents the expected shape without allocating any additional memory
+            attention_bias = torch.zeros(
+                1, 1, 1, dtype=query.dtype, device=query.device
+            ).expand(batch_x_heads, query_tokens, key_tokens)
             beta = 0
         else:
-            baddbmm_input = attention_mask
+            attention_bias = attention_mask
             beta = 1
 
         attention_scores = torch.baddbmm(
-            baddbmm_input,
+            attention_bias,
             query,
             key.transpose(-1, -2),
             beta=beta,
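
For context on the change: torch.empty() returns an uninitialized buffer, which baddbmm() with beta=0 should never read, but on MPS (PyTorch 2.1.0.dev20230310) it apparently does, producing NaNs. torch.zeros(1, 1, 1).expand(...) instead hands baddbmm a defined, all-zero bias whose strides advertise the full (batch_x_heads, query_tokens, key_tokens) shape without allocating it. Below is a minimal standalone sketch of the same pattern; the tensor sizes and the scale value are illustrative, not taken from diffusers.

import torch

batch_x_heads, query_tokens, key_tokens, head_dim = 2, 4, 6, 8  # illustrative sizes
query = torch.randn(batch_x_heads, query_tokens, head_dim)
key = torch.randn(batch_x_heads, key_tokens, head_dim)
scale = head_dim ** -0.5  # typical attention scaling factor (assumed here)

# One zero element, expanded (stride tricks, no copy) to the bias shape baddbmm expects.
attention_bias = torch.zeros(1, 1, 1, dtype=query.dtype, device=query.device).expand(
    batch_x_heads, query_tokens, key_tokens
)

# With beta=0 the bias contributes nothing; the result is alpha * (query @ key^T).
attention_scores = torch.baddbmm(
    attention_bias,
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
print(attention_scores.shape)  # torch.Size([2, 4, 6])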

0 commit comments
