Skip to content

Commit e619db2

Browse files
authored
mps cross-attention hack: don't crash on fp16 (#2258)
* mps cross-attention hack: don't crash on fp16 * Make conversion explicit.
1 parent 111228c commit e619db2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/cross_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
251251
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
252252
# Instead, we can manually construct the padding tensor.
253253
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
254-
padding = torch.zeros(padding_shape, device=attention_mask.device)
254+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
255255
attention_mask = torch.concat([attention_mask, padding], dim=2)
256256
else:
257257
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

0 commit comments

Comments
 (0)