 import math
-from typing import Optional
+import os
+from inspect import isfunction
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
+import xformers
+import xformers.ops
+from einops import rearrange
+
+
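+# Opt-in switch: setting USE_MEMORY_EFFICIENT_ATTENTION=1 in the environment routes
+# cross-attention through xformers' memory-efficient kernels instead of the default path.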
+_USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1
+
+
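+# Small helpers: `exists` checks for None, `default` falls back to `d`
+# (calling it first if it is a zero-argument factory function).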
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
 
 class AttentionBlock(nn.Module):
     """
@@ -177,11 +196,12 @@ def __init__(
         checkpoint: bool = True,
     ):
         super().__init__()
-        self.attn1 = CrossAttention(
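+        # Pick the attention implementation once, based on the opt-in environment flag above.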
+        AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention
+        self.attn1 = AttentionBuilder(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(
+        self.attn2 = AttentionBuilder(
             query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
         )  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
@@ -201,6 +221,77 @@ def forward(self, hidden_states, context=None):
         return hidden_states
 
 
+class MemoryEfficientCrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
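+        # Linear projections to the multi-head query/key/value space (heads folded into the feature dim).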
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def _maybe_init(self, x):
+        """
+        Initialize the attention operator, if required.
+        We expect the heads to be folded into the batch dimension here,
+        i.e. x is shaped (batch * heads, seq_len, head_dim).
+        """
+        if self.attention_op is not None:
+            return
+
+        _, M, K = x.shape
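+        # Ask xformers to dispatch a suitable kernel (e.g. Flash or CUTLASS attention)
+        # for this dtype, device, head dimension and sequence length.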
+        try:
+            self.attention_op = xformers.ops.AttentionOpDispatch(
+                dtype=x.dtype,
+                device=x.device,
+                k=K,
+                attn_bias_type=type(None),
+                has_dropout=False,
+                kv_len=M,
+                q_len=M,
+            ).op
+
+        except NotImplementedError as err:
+            raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
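+        # Fold the heads into the batch dimension: (b, n, h*d) -> (b*h, n, d), the layout xformers expects.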
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h).contiguous(),
+            (q, k, v),
+        )
+
+        # init the attention op, if required, using the proper dimensions
+        self._maybe_init(q)
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        # TODO: Use this directly in the attention operation, as a bias
+        if exists(mask):
+            raise NotImplementedError
+            # mask = rearrange(mask, "b ... -> b (...)")
+            # max_neg_value = -torch.finfo(sim.dtype).max
+            # mask = repeat(mask, "b j -> (b h) () j", h=h)
+            # sim.masked_fill_(~mask, max_neg_value)
+
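+        # Unfold the heads back out of the batch dimension: (b*h, n, d) -> (b, n, h*d).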
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
+
+
 class CrossAttention(nn.Module):
     r"""
     A cross attention layer.