
Commit cdf071a (parent: c559479)

for the (unused) attention bias which the baddbmm API requires us to pass in: create a smaller bias and broadcast it. this helps to prevent NaN results from baddbmm() on MPS on PyTorch 2.1.0.dev20230310.

File tree

1 file changed (+8, -4 lines)

src/diffusers/models/attention_processor.py (8 additions, 4 deletions)
@@ -298,16 +298,20 @@ def get_attention_scores(self, query, key, attention_mask=None):
             key = key.float()
 
         if attention_mask is None:
-            baddbmm_input = torch.empty(
-                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            batch_x_heads, query_tokens, _ = query.shape
+            _, key_tokens, _ = key.shape
+            # expanding dims isn't strictly necessary (baddbmm supports broadcasting bias),
+            # but documents the expected shape without allocating any additional memory
+            attention_bias = torch.zeros(1, 1, 1, dtype=query.dtype, device=query.device).expand(
+                batch_x_heads, query_tokens, key_tokens
             )
             beta = 0
         else:
-            baddbmm_input = attention_mask
+            attention_bias = attention_mask
             beta = 1
 
         attention_scores = torch.baddbmm(
-            baddbmm_input,
+            attention_bias,
             query,
             key.transpose(-1, -2),
             beta=beta,

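The same pattern can be exercised outside of diffusers. The sketch below (shapes and variable names are illustrative assumptions, not taken from the library) shows the broadcast-bias trick: torch.zeros(1, 1, 1).expand(...) yields a stride-0 view backed by a single element, so it documents the expected bias shape without allocating a full batch_x_heads x query_tokens x key_tokens buffer, and, unlike torch.empty(), it never hands baddbmm() uninitialized values.

import torch

# Illustrative shapes (assumptions for this sketch); batch_x_heads = batch * heads.
batch_x_heads, query_tokens, key_tokens, head_dim = 16, 77, 77, 64

query = torch.randn(batch_x_heads, query_tokens, head_dim)
key = torch.randn(batch_x_heads, key_tokens, head_dim)

# One zero element expanded to the full bias shape: expand() returns a
# stride-0 view, so no extra memory is allocated and no value is uninitialized.
attention_bias = torch.zeros(1, 1, 1, dtype=query.dtype, device=query.device).expand(
    batch_x_heads, query_tokens, key_tokens
)
assert attention_bias.stride() == (0, 0, 0)  # a view over a single element

# beta=0 means the bias contributes nothing arithmetically, but the baddbmm
# API still requires a tensor broadcastable to the output shape.
attention_scores = torch.baddbmm(
    attention_bias,
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=1.0,
)
assert attention_scores.shape == (batch_x_heads, query_tokens, key_tokens)

Per the commit message, the previous torch.empty() bias could contain arbitrary (possibly NaN) values, which the MPS backend in that PyTorch nightly appeared to propagate into the result even with beta=0; a zero-filled broadcast view sidesteps that without paying for a full-size allocation.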
0 commit comments