From cdf071ae76589f21ce17544f97ce247113be62cd Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 9 May 2023 23:49:04 +0100 Subject: [PATCH] for the (unused) attention bias which the baddbmm API requires us to pass in: create a smaller bias and broadcast it. this helps to prevent NaN result from baddbmm() on MPS on PyTorch 2.1.0.dev20230310. --- src/diffusers/models/attention_processor.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6701122fc13b..b7de2c73d630 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -298,16 +298,20 @@ def get_attention_scores(self, query, key, attention_mask=None): key = key.float() if attention_mask is None: - baddbmm_input = torch.empty( - query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + batch_x_heads, query_tokens, _ = query.shape + _, key_tokens, _ = key.shape + # expanding dims isn't strictly necessary (baddbmm supports broadcasting bias), + # but documents the expected shape without allocating any additional memory + attention_bias = torch.zeros(1, 1, 1, dtype=query.dtype, device=query.device).expand( + batch_x_heads, query_tokens, key_tokens ) beta = 0 else: - baddbmm_input = attention_mask + attention_bias = attention_mask beta = 1 attention_scores = torch.baddbmm( - baddbmm_input, + attention_bias, query, key.transpose(-1, -2), beta=beta,