From 07ec4ab5d87a75ea59ce85cd0479c494b20dfb91 Mon Sep 17 00:00:00 2001
From: Haotian Liu <6631389+haotian-liu@users.noreply.github.com>
Date: Wed, 22 Mar 2023 02:42:14 -0500
Subject: [PATCH] Use native memory efficient attention in PyTorch 2.0 if possible

---
 src/diffusers/models/attention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 5c7e54e7cd32..caa021d67b98 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -144,6 +144,9 @@ def forward(self, hidden_states):
                 query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
             )
             hidden_states = hidden_states.to(query_proj.dtype)
+        elif hasattr(F, 'scaled_dot_product_attention'):
+            # PyTorch 2.0: native Flash/memory_efficient_attention
+            hidden_states = F.scaled_dot_product_attention(query_proj, key_proj, value_proj)
         else:
             attention_scores = torch.baddbmm(
                 torch.empty(
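
A minimal standalone sketch of the dispatch pattern this patch adds, independent of the diffusers classes: prefer torch.nn.functional.scaled_dot_product_attention when the running PyTorch exposes it (2.0+), otherwise fall back to an explicit attention matmul. The function name attention_with_fallback and the tensor shapes are illustrative assumptions, not part of the patched code.

    # Sketch only: illustrates the hasattr-based dispatch used in the patch.
    import math
    import torch
    import torch.nn.functional as F


    def attention_with_fallback(query, key, value):
        # query/key/value: (batch, heads, seq_len, head_dim) -- assumed layout for this sketch
        if hasattr(F, "scaled_dot_product_attention"):
            # PyTorch 2.0+: routes to Flash / memory-efficient kernels when eligible
            return F.scaled_dot_product_attention(query, key, value)
        # Pre-2.0 fallback: materialize the full attention matrix explicitly
        scale = 1.0 / math.sqrt(query.shape[-1])
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * scale
        attention_probs = attention_scores.softmax(dim=-1)
        return torch.matmul(attention_probs, value)


    if __name__ == "__main__":
        q = torch.randn(2, 8, 64, 32)
        k = torch.randn(2, 8, 64, 32)
        v = torch.randn(2, 8, 64, 32)
        print(attention_with_fallback(q, k, v).shape)  # torch.Size([2, 8, 64, 32])

The hasattr check keeps the code importable on older PyTorch releases while letting newer installs use the fused kernel without any configuration.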