 import math
-from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -11,24 +10,16 @@ class AttentionBlock(nn.Module):
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    Uses three q, k, v linear layers to compute attention.
-
-    Parameters:
-        channels (:obj:`int`): The number of channels in the input and output.
-        num_head_channels (:obj:`int`, *optional*):
-            The number of channels in each head. If None, then `num_heads` = 1.
-        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
-        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
-        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+    Uses three q, k, v linear layers to compute attention
     """

     def __init__(
         self,
-        channels: int,
-        num_head_channels: Optional[int] = None,
-        num_groups: int = 32,
-        rescale_output_factor: float = 1.0,
-        eps: float = 1e-5,
+        channels,
+        num_head_channels=None,
+        num_groups=32,
+        rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
         self.channels = channels
@@ -95,26 +86,10 @@ def forward(self, hidden_states):
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
-    standard transformer action. Finally, reshape to image.
-
-    Parameters:
-        in_channels (:obj:`int`): The number of channels in the input and output.
-        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
-        d_head (:obj:`int`): The number of channels in each head.
-        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
-        dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
-        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
+    standard transformer action. Finally, reshape to image
     """

-    def __init__(
-        self,
-        in_channels: int,
-        n_heads: int,
-        d_head: int,
-        depth: int = 1,
-        dropout: float = 0.0,
-        context_dim: Optional[int] = None,
-    ):
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
         self.n_heads = n_heads
         self.d_head = d_head
@@ -137,44 +112,22 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)

-    def forward(self, hidden_states, context=None):
+    def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        batch, channel, height, weight = hidden_states.shape
-        residual = hidden_states
-        hidden_states = self.norm(hidden_states)
-        hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = self.proj_in(x)
+        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
         for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
-        hidden_states = self.proj_out(hidden_states)
-        return hidden_states + residual
+            x = block(x, context=context)
+        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+        x = self.proj_out(x)
+        return x + x_in


 class BasicTransformerBlock(nn.Module):
-    r"""
-    A basic Transformer block.
-
-    Parameters:
-        dim (:obj:`int`): The number of channels in the input and output.
-        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
-        d_head (:obj:`int`): The number of channels in each head.
-        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
-        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
-        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        n_heads: int,
-        d_head: int,
-        dropout=0.0,
-        context_dim: Optional[int] = None,
-        gated_ff: bool = True,
-        checkpoint: bool = True,
-    ):
+    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
         super().__init__()
         self.attn1 = CrossAttention(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
@@ -192,30 +145,15 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size

-    def forward(self, hidden_states, context=None):
-        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
-        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
-        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
-        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
-        return hidden_states
+    def forward(self, x, context=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x


 class CrossAttention(nn.Module):
-    r"""
-    A cross attention layer.
-
-    Parameters:
-        query_dim (:obj:`int`): The number of channels in the query.
-        context_dim (:obj:`int`, *optional*):
-            The number of channels in the context. If not given, defaults to `query_dim`.
-        heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
-        dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
-        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
-    """
-
-    def __init__(
-        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
-    ):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = context_dim if context_dim is not None else query_dim
@@ -236,104 +174,77 @@ def __init__(
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor2 = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor3
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor

     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor2 = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor3
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor

-    def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, dim = hidden_states.shape
+    def forward(self, x, context=None, mask=None):
+        batch_size, sequence_length, dim = x.shape

-        query = self.to_q(hidden_states)
-        context = context if context is not None else hidden_states
-        key = self.to_k(context)
-        value = self.to_v(context)
+        q = self.to_q(x)
+        context = context if context is not None else x
+        k = self.to_k(context)
+        v = self.to_v(context)

-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        q = self.reshape_heads_to_batch_dim(q)
+        k = self.reshape_heads_to_batch_dim(k)
+        v = self.reshape_heads_to_batch_dim(v)

         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

         # attention, what we cannot get enough of
-        hidden_states = self._attention(query, key, value, sequence_length, dim)
+        hidden_states = self._attention(q, k, v, sequence_length, dim)

         return self.to_out(hidden_states)

     def _attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
-        # hidden_states = torch.zeros(
-        #     (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
-        # )
-        slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
-        # for i in range(hidden_states.shape[0] // slice_size):
-        #     start_idx = i * slice_size
-        #     end_idx = (i + 1) * slice_size
-        #     qslice = query[start_idx:end_idx]
-        qslice = query
-        #     kslice = key[start_idx:end_idx].transpose(1, 2)
-        kslice = key.transpose(1, 2)
-        attn_slice = torch.matmul(qslice, kslice) * self.scale
-        attn_slice = attn_slice.softmax(dim=-1)
-        #     vslice = value[start_idx:end_idx]
-        vslice = value
-        hidden_states = torch.matmul(attn_slice, vslice)
-
-
-        # hidden_states = torch.cat(attn_slices, dim=0)
-
+        hidden_states = torch.zeros(
+            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+        )
+        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+        for i in range(hidden_states.shape[0] // slice_size):
+            start_idx = i * slice_size
+            end_idx = (i + 1) * slice_size
+            attn_slice = (
+                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
+            )
+            attn_slice = attn_slice.softmax(dim=-1)
+            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+
+            hidden_states[start_idx:end_idx] = attn_slice

         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states


 class FeedForward(nn.Module):
-    r"""
-    A feed-forward layer.
-
-    Parameters:
-        dim (:obj:`int`): The number of channels in the input.
-        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
-        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
-        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
-        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
-    """
-
-    def __init__(
-        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
-    ):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
         project_in = GEGLU(dim, inner_dim)

         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

-    def forward(self, hidden_states):
-        return self.net(hidden_states)
+    def forward(self, x):
+        return self.net(x)


 # feedforward
 class GEGLU(nn.Module):
-    r"""
-    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
-
-    Parameters:
-        dim_in (:obj:`int`): The number of channels in the input.
-        dim_out (:obj:`int`): The number of channels in the output.
-    """
-
-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in, dim_out):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)

-    def forward(self, hidden_states):
-        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * F.gelu(gate)
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
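
For reference, a minimal sketch (not part of the commit) of how the restored sliced `_attention` path can be checked against a single full-attention pass. It assumes the classes above are importable from `diffusers.models.attention` and that `CrossAttention.__init__` sets `self._slice_size = None` (as the hunk above relies on); adjust the import to wherever this file lives in your tree.

# Sketch: verify sliced attention matches the unsliced pass, under the
# assumptions stated above (import path and a default `_slice_size = None`).
import torch

from diffusers.models.attention import CrossAttention  # assumed module path

torch.manual_seed(0)
attn = CrossAttention(query_dim=64, heads=4, dim_head=16)
attn.eval()  # make both passes deterministic

x = torch.randn(2, 32, 64)  # (batch, sequence_length, query_dim)

with torch.no_grad():
    full = attn(x)        # _slice_size is None -> one slice over batch * heads
    attn._slice_size = 2  # same attribute BasicTransformerBlock._set_attention_slice assigns
    sliced = attn(x)      # 4 slices of size 2 over the batch * heads dimension

print(torch.allclose(full, sliced, atol=1e-6))  # expected: True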