diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 6701122fc13b..b7de2c73d630 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -298,16 +298,20 @@ def get_attention_scores(self, query, key, attention_mask=None):
             key = key.float()
 
         if attention_mask is None:
-            baddbmm_input = torch.empty(
-                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            batch_x_heads, query_tokens, _ = query.shape
+            _, key_tokens, _ = key.shape
+            # expanding dims isn't strictly necessary (baddbmm supports broadcasting bias),
+            # but documents the expected shape without allocating any additional memory
+            attention_bias = torch.zeros(1, 1, 1, dtype=query.dtype, device=query.device).expand(
+                batch_x_heads, query_tokens, key_tokens
             )
             beta = 0
         else:
-            baddbmm_input = attention_mask
+            attention_bias = attention_mask
             beta = 1
 
         attention_scores = torch.baddbmm(
-            baddbmm_input,
+            attention_bias,
             query,
             key.transpose(-1, -2),
             beta=beta,
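
For context (not part of the patch): `.expand()` returns a zero-stride view over the single zero element, so the broadcast-shaped bias documents the `(batch_x_heads, query_tokens, key_tokens)` shape that `torch.baddbmm` expects without materializing a full buffer the way `torch.empty` did. A minimal sketch illustrating this, using arbitrary example shapes not taken from the patch:

```python
import torch

# Example shapes, chosen arbitrarily for illustration.
batch_x_heads, query_tokens, key_tokens, head_dim = 16, 4096, 77, 64

# The patch's approach: a single zero element, viewed at the full bias shape.
attention_bias = torch.zeros(1, 1, 1).expand(batch_x_heads, query_tokens, key_tokens)
print(attention_bias.shape)    # torch.Size([16, 4096, 77])
print(attention_bias.stride()) # (0, 0, 0) -- a zero-stride view, no new allocation

# The old approach reserved the whole buffer up front, and left it uninitialized:
baddbmm_input = torch.empty(batch_x_heads, query_tokens, key_tokens)
print(baddbmm_input.stride())  # (315392, 77, 1) -- a real, fully allocated buffer

# Either way, beta=0 tells baddbmm to ignore the bias term when computing
# beta * bias + alpha * (query @ key^T).
query = torch.randn(batch_x_heads, query_tokens, head_dim)
key = torch.randn(batch_x_heads, key_tokens, head_dim)
scores = torch.baddbmm(attention_bias, query, key.transpose(-1, -2), beta=0, alpha=1.0)
print(scores.shape)            # torch.Size([16, 4096, 77])
```

A side benefit of `zeros` over `empty`, arguably, is that the bias now holds defined values rather than uninitialized memory, so the `beta=0` path no longer relies on every backend ignoring garbage in the bias tensor.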