Commit fb4e372

2x speedup using memory efficient attention
1 parent 761f029 commit fb4e372

2 files changed: 97 additions, 4 deletions

setup.py

Lines changed: 3 additions & 1 deletion
@@ -182,7 +182,9 @@ def run(self):
 else:
     extras["flax"] = deps_list("jax", "jaxlib", "flax")
 
-extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+extras["dev"] = (
+    extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+)
 
 install_requires = [
     deps["importlib_metadata"],

src/diffusers/models/attention.py

Lines changed: 94 additions & 3 deletions
@@ -1,10 +1,29 @@
 import math
-from typing import Optional
+import os
+from inspect import isfunction
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
+import xformers
+import xformers.ops
+from einops import rearrange
+
+
+_USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
 
 class AttentionBlock(nn.Module):
     """
@@ -177,11 +196,12 @@ def __init__(
         checkpoint: bool = True,
     ):
         super().__init__()
-        self.attn1 = CrossAttention(
+        AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention
+        self.attn1 = AttentionBuilder(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(
+        self.attn2 = AttentionBuilder(
             query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
         )  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
@@ -201,6 +221,77 @@ def forward(self, hidden_states, context=None):
         return hidden_states
 
 
+class MemoryEfficientCrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def _maybe_init(self, x):
+        """
+        Initialize the attention operator, if required
+        We expect the head dimension to be exposed here, meaning that
+        x : B, Head, Length
+        """
+        if self.attention_op is not None:
+            return
+
+        _, K, M = x.shape
+        try:
+            self.attention_op = xformers.ops.AttentionOpDispatch(
+                dtype=x.dtype,
+                device=x.device,
+                k=K,
+                attn_bias_type=type(None),
+                has_dropout=False,
+                kv_len=M,
+                q_len=M,
+            ).op
+
+        except NotImplementedError as err:
+            raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h).contiguous(),
+            (q, k, v),
+        )
+
+        # init the attention op, if required, using the proper dimensions
+        self._maybe_init(q)
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        # TODO: Use this directly in the attention operation, as a bias
+        if exists(mask):
+            raise NotImplementedError
+            # mask = rearrange(mask, "b ... -> b (...)")
+            # max_neg_value = -torch.finfo(sim.dtype).max
+            # mask = repeat(mask, "b j -> (b h) () j", h=h)
+            # sim.masked_fill_(~mask, max_neg_value)
+
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
+
+
 class CrossAttention(nn.Module):
     r"""
     A cross attention layer.
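
For illustration only, a minimal sketch of exercising the new module in isolation, assuming a CUDA device, fp16 weights, and an xformers build that provides the flash-attention / cutlass kernels (the dimensions are made up, roughly matching a Stable Diffusion UNet block):

import torch

attn = MemoryEfficientCrossAttention(query_dim=320, heads=8, dim_head=40).half().cuda()
x = torch.randn(2, 4096, 320, device="cuda", dtype=torch.float16)  # (batch, tokens, query_dim)

with torch.no_grad():
    out = attn(x)  # context defaults to x, i.e. plain self-attention

print(out.shape)  # torch.Size([2, 4096, 320])

Up to numerical tolerance the result should match the existing CrossAttention module, since xformers computes the same exact attention, just without materializing the full attention matrix.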
