
Commit a29c689

2x speedup using memory efficient attention
1 parent 761f029 commit a29c689

2 files changed: +96 -4 lines


setup.py

Lines changed: 3 additions & 1 deletion
@@ -182,7 +182,9 @@ def run(self):
 else:
     extras["flax"] = deps_list("jax", "jaxlib", "flax")
 
-extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+extras["dev"] = (
+    extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+)
 
 install_requires = [
     deps["importlib_metadata"],

src/diffusers/models/attention.py

Lines changed: 93 additions & 3 deletions
@@ -1,10 +1,29 @@
 import math
-from typing import Optional
+import os
+from inspect import isfunction
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
+import xformers
+import xformers.ops
+from einops import rearrange
+
+
+_USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
 
 class AttentionBlock(nn.Module):
     """
@@ -177,11 +196,12 @@ def __init__(
         checkpoint: bool = True,
     ):
         super().__init__()
-        self.attn1 = CrossAttention(
+        AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention
+        self.attn1 = AttentionBuilder(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(
+        self.attn2 = AttentionBuilder(
             query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
         )  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
@@ -201,6 +221,76 @@ def forward(self, hidden_states, context=None):
         return hidden_states
 
 
+class MemoryEfficientCrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def _maybe_init(self, x):
+        """
+        Initialize the attention operator, if required. We expect the head dimension to be exposed here,
+        meaning that x has shape (B, Head, Length).
+        """
+        if self.attention_op is not None:
+            return
+
+        _, K, M = x.shape
+        try:
+            self.attention_op = xformers.ops.AttentionOpDispatch(
+                dtype=x.dtype,
+                device=x.device,
+                k=K,
+                attn_bias_type=type(None),
+                has_dropout=False,
+                kv_len=M,
+                q_len=M,
+            ).op
+
+        except NotImplementedError as err:
+            raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h).contiguous(),
+            (q, k, v),
+        )
+
+        # init the attention op, if required, using the proper dimensions
+        self._maybe_init(q)
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        # TODO: Use this directly in the attention operation, as a bias
+        if exists(mask):
+            raise NotImplementedError
+            # mask = rearrange(mask, "b ... -> b (...)")
+            # max_neg_value = -torch.finfo(sim.dtype).max
+            # mask = repeat(mask, "b j -> (b h) () j", h=h)
+            # sim.masked_fill_(~mask, max_neg_value)
+
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
+
+
 class CrossAttention(nn.Module):
     r"""
     A cross attention layer.
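
The _USE_MEMORY_EFFICIENT_ATTENTION flag is read once at import time, so it must be set in the environment before diffusers.models.attention is imported. A minimal sketch of exercising the new class directly; this is illustrative rather than part of the commit, and assumes a CUDA device, fp16 tensors, and an xformers build that provides the AttentionOpDispatch API used above. The tensor sizes are made up for the example.

import os

# The flag is evaluated when the module is imported, so set it first.
os.environ["USE_MEMORY_EFFICIENT_ATTENTION"] = "1"

import torch
from diffusers.models.attention import MemoryEfficientCrossAttention

# Self-attention over 2 sequences of length 64 with 320-dim features (illustrative sizes).
attn = MemoryEfficientCrossAttention(query_dim=320, heads=8, dim_head=40).cuda().half()
x = torch.randn(2, 64, 320, device="cuda", dtype=torch.float16)

out = attn(x)     # context defaults to x, i.e. self-attention
print(out.shape)  # torch.Size([2, 64, 320])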

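For reference, the xformers kernel called in forward computes the same softmax(QK^T / sqrt(d)) V as the standard attention path it replaces, just without materializing the full attention matrix. A small numerical check, again illustrative and not part of the commit, assuming xformers.ops.memory_efficient_attention is available and a CUDA device is present:

import torch
import xformers.ops

B, N, D = 4, 128, 64  # (batch * heads, sequence length, head dim), illustrative sizes
q, k, v = (torch.randn(B, N, D, device="cuda", dtype=torch.float16) for _ in range(3))

# Memory-efficient attention (the default scale is 1 / sqrt(D)).
out_xformers = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

# Reference: plain scaled dot-product attention.
scale = D ** -0.5
weights = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
out_reference = weights @ v

print(torch.allclose(out_xformers.float(), out_reference.float(), atol=1e-2))  # expected: True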