 import math
-from typing import Optional
+import os
+from inspect import isfunction
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
+import xformers
+import xformers.ops
+from einops import rearrange
+
+
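+# Opt-in switch: set USE_MEMORY_EFFICIENT_ATTENTION=1 in the environment to route the
+# transformer attention layers below through the xformers-backed implementation.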
+_USE_MEMORY_EFFICIENT_ATTENTION = int(os.environ.get("USE_MEMORY_EFFICIENT_ATTENTION", 0)) == 1
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
 
 class AttentionBlock(nn.Module):
     """
@@ -177,11 +196,12 @@ def __init__(
         checkpoint: bool = True,
     ):
         super().__init__()
-        self.attn1 = CrossAttention(
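+        # Choose the attention implementation once, based on the opt-in environment flag above.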
+        AttentionBuilder = MemoryEfficientCrossAttention if _USE_MEMORY_EFFICIENT_ATTENTION else CrossAttention
+        self.attn1 = AttentionBuilder(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(
+        self.attn2 = AttentionBuilder(
             query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
         )  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
@@ -201,6 +221,76 @@ def forward(self, hidden_states, context=None):
         return hidden_states
 
 
+class MemoryEfficientCrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
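+        # Selected xformers attention operator; resolved lazily on the first forward
+        # pass, once the runtime dtype, device and head dimension are known.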
+        self.attention_op: Optional[Any] = None
+
+    def _maybe_init(self, x):
+        """
+        Initialize the attention operator, if required. The head dimension is expected to be folded into the
+        batch dimension here, i.e. x has shape (batch * heads, length, dim_head).
+        """
+        if self.attention_op is not None:
+            return
+
+        _, M, K = x.shape
+        try:
+            self.attention_op = xformers.ops.AttentionOpDispatch(
+                dtype=x.dtype,
+                device=x.device,
+                k=K,
+                attn_bias_type=type(None),
+                has_dropout=False,
+                kv_len=M,
+                q_len=M,
+            ).op
+
+        except NotImplementedError as err:
+            raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
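+        # Fold the heads into the batch dimension for q, k and v:
+        # (B, seq, heads * dim_head) -> (B * heads, seq, dim_head), contiguous as expected by xformers.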
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h).contiguous(),
+            (q, k, v),
+        )
+
+        # init the attention op, if required, using the proper dimensions
+        self._maybe_init(q)
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        # TODO: Use this directly in the attention operation, as a bias
+        if exists(mask):
+            raise NotImplementedError
+            # mask = rearrange(mask, "b ... -> b (...)")
+            # max_neg_value = -torch.finfo(sim.dtype).max
+            # mask = repeat(mask, "b j -> (b h) () j", h=h)
+            # sim.masked_fill_(~mask, max_neg_value)
+
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
+
+
 class CrossAttention(nn.Module):
     r"""
     A cross attention layer.