
Commit 5a47af1

pcuenca and anton-l authored
mps changes for PyTorch 1.13 (huggingface#926)
* Docs: refer to pre-RC version of PyTorch 1.13.0.
* Remove temporary workaround for unavailable op.
* Update comment to make it less ambiguous.
* Remove use of contiguous in mps. It appears to no longer be necessary.
* Special case: use einsum for much better performance in mps
* Update mps docs.
* Minor doc update.
* Accept suggestion

Co-authored-by: Anton Lozhkov <[email protected]>
1 parent 6ad4e1e commit 5a47af1

3 files changed: 26 additions (+), 12 deletions (−)

models/attention.py

Lines changed: 25 additions & 7 deletions

@@ -207,7 +207,6 @@ def _set_attention_slice(self, slice_size):
         self.attn2._slice_size = slice_size
 
     def forward(self, hidden_states, context=None):
-        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
         hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
         hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -288,10 +287,19 @@ def forward(self, hidden_states, context=None, mask=None):
 
     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        if query.device.type == "mps":
+            # Better performance on mps (~20-25%)
+            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
+        else:
+            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value)
+
+        if query.device.type == "mps":
+            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
+        else:
+            hidden_states = torch.matmul(attention_probs, value)
+
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
@@ -305,11 +313,21 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-            )  # TODO: use baddbmm for better performance
+            if query.device.type == "mps":
+                # Better performance on mps (~20-25%)
+                attn_slice = (
+                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
+                    * self.scale
+                )
+            else:
+                attn_slice = (
+                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+                )  # TODO: use baddbmm for better performance
             attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            if query.device.type == "mps":
+                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+            else:
+                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
 
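The einsum calls introduced above are drop-in equivalents of the matmul expressions they replace; the mps branch exists only because, per the inline comments, the einsum form benchmarks roughly 20-25% faster on that backend. Below is a minimal standalone sketch, not part of the commit, that checks the equivalence with purely illustrative tensor shapes:

import torch

# (batch * heads, seq_len, head_dim); shapes chosen only for illustration
q = torch.randn(2, 4, 8)
k = torch.randn(2, 6, 8)
v = torch.randn(2, 6, 8)
scale = 8 ** -0.5

# Attention scores: einsum over the shared head dimension vs. matmul with a transpose
scores_einsum = torch.einsum("b i d, b j d -> b i j", q, k) * scale
scores_matmul = torch.matmul(q, k.transpose(-1, -2)) * scale
assert torch.allclose(scores_einsum, scores_matmul, atol=1e-5)

probs = scores_einsum.softmax(dim=-1)

# Weighted sum over values: the same equivalence in the other direction
out_einsum = torch.einsum("b i j, b j d -> b i d", probs, v)
out_matmul = torch.matmul(probs, v)
assert torch.allclose(out_einsum, out_matmul, atol=1e-5)
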
models/resnet.py

Lines changed: 0 additions & 4 deletions

@@ -492,10 +492,6 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
     kernel_h, kernel_w = kernel.shape
 
     out = tensor.view(-1, in_h, 1, in_w, 1, minor)
-
-    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
-    if tensor.device.type == "mps":
-        out = out.to("cpu")
     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
     out = out.view(-1, in_h * up_y, in_w * up_x, minor)

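The deleted lines had routed the 6-D intermediate through the CPU before padding because the op was unavailable on mps (pytorch/pytorch#84535); the commit assumes PyTorch 1.13 runs it natively. A minimal standalone sketch of that pad pattern, not part of the commit, with purely illustrative shapes (up_x = up_y = 2, minor = 3):

import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    # 6-D view analogous to out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = torch.randn(1, 8, 1, 8, 1, 3, device="mps")
    # Pad the two singleton upsampling axes by up_x - 1 = up_y - 1 = 1
    out = F.pad(out, [0, 0, 0, 1, 0, 0, 0, 1])
    print(out.shape)  # expected: torch.Size([1, 8, 2, 8, 2, 3])
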
pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 1 deletion

@@ -287,7 +287,7 @@ def __call__(
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )

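The reworded comment reflects that torch.randn is available on mps but seeded sampling there is not reproducible, so the pipeline keeps drawing the initial latents on the CPU with the user's generator and then moves them to the device. A minimal standalone sketch of that pattern, not part of the commit; the shape, dtype, and seed are illustrative:

import torch

latents_shape = (1, 4, 64, 64)  # illustrative latent shape
generator = torch.Generator(device="cpu").manual_seed(0)

# Sample on the CPU with a seeded generator, then move the result to mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=torch.float32)
if torch.backends.mps.is_available():
    latents = latents.to("mps")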