
Commit e422eb3

Revert "attn refactoring"
This reverts commit 0c70c0e.
1 parent 0c70c0e commit e422eb3

File tree

1 file changed: +62 -151 lines changed


src/diffusers/models/attention.py

Lines changed: 62 additions & 151 deletions
@@ -1,5 +1,4 @@
 import math
-from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -11,24 +10,16 @@ class AttentionBlock(nn.Module):
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    Uses three q, k, v linear layers to compute attention.
-
-    Parameters:
-        channels (:obj:`int`): The number of channels in the input and output.
-        num_head_channels (:obj:`int`, *optional*):
-            The number of channels in each head. If None, then `num_heads` = 1.
-        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
-        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
-        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+    Uses three q, k, v linear layers to compute attention
     """
 
     def __init__(
         self,
-        channels: int,
-        num_head_channels: Optional[int] = None,
-        num_groups: int = 32,
-        rescale_output_factor: float = 1.0,
-        eps: float = 1e-5,
+        channels,
+        num_head_channels=None,
+        num_groups=32,
+        rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
         self.channels = channels
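For orientation, a usage sketch of the reverted constructor. The sizes below are hypothetical, and deriving the head count as channels // num_head_channels is an assumption consistent with the removed docstring ("If None, then `num_heads` = 1"), not something this hunk shows directly:

import torch

from diffusers.models.attention import AttentionBlock  # the file shown in this diff

# Hypothetical sizes: 512 channels, 64 channels per head -> presumably 8 heads.
block = AttentionBlock(
    channels=512,
    num_head_channels=64,      # None would fall back to a single head
    num_groups=32,             # group-norm groups, matching the default above
    rescale_output_factor=1.0,
    eps=1e-5,
)

hidden_states = torch.randn(1, 512, 16, 16)  # (batch, channels, height, width)
out = block(hidden_states)
print(out.shape)  # expected to match the input: torch.Size([1, 512, 16, 16])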
@@ -95,26 +86,10 @@ def forward(self, hidden_states):
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
-    standard transformer action. Finally, reshape to image.
-
-    Parameters:
-        in_channels (:obj:`int`): The number of channels in the input and output.
-        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
-        d_head (:obj:`int`): The number of channels in each head.
-        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
-        dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
-        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
+    standard transformer action. Finally, reshape to image
     """
 
-    def __init__(
-        self,
-        in_channels: int,
-        n_heads: int,
-        d_head: int,
-        depth: int = 1,
-        dropout: float = 0.0,
-        context_dim: Optional[int] = None,
-    ):
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
         self.n_heads = n_heads
         self.d_head = d_head
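Likewise, a hedged instantiation sketch for SpatialTransformer. The concrete numbers are hypothetical (chosen to look like a latent-diffusion UNet stage) and only illustrate how n_heads, d_head, and context_dim fit together; context carries the conditioning sequence used for cross-attention:

import torch

from diffusers.models.attention import SpatialTransformer  # the file shown in this diff

# Hypothetical sizes: 8 heads of 40 channels over a 320-channel feature map.
transformer = SpatialTransformer(
    in_channels=320,
    n_heads=8,
    d_head=40,
    depth=1,
    dropout=0.0,
    context_dim=768,  # width of the conditioning (e.g. text-encoder) states
)

x = torch.randn(2, 320, 32, 32)    # image-like feature map (b, c, h, w)
context = torch.randn(2, 77, 768)  # conditioning sequence (b, t, context_dim)
out = transformer(x, context=context)
print(out.shape)  # expected to match the input: torch.Size([2, 320, 32, 32])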
@@ -137,44 +112,22 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)
 
-    def forward(self, hidden_states, context=None):
+    def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        batch, channel, height, weight = hidden_states.shape
-        residual = hidden_states
-        hidden_states = self.norm(hidden_states)
-        hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = self.proj_in(x)
+        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
         for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
-        hidden_states = self.proj_out(hidden_states)
-        return hidden_states + residual
+            x = block(x, context=context)
+        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+        x = self.proj_out(x)
+        return x + x_in
 
 
 class BasicTransformerBlock(nn.Module):
-    r"""
-    A basic Transformer block.
-
-    Parameters:
-        dim (:obj:`int`): The number of channels in the input and output.
-        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
-        d_head (:obj:`int`): The number of channels in each head.
-        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
-        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
-        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        n_heads: int,
-        d_head: int,
-        dropout=0.0,
-        context_dim: Optional[int] = None,
-        gated_ff: bool = True,
-        checkpoint: bool = True,
-    ):
+    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
         super().__init__()
         self.attn1 = CrossAttention(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
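The reverted forward above flattens the (b, c, h, w) map into a (b, h*w, c) token sequence for the transformer blocks and then restores the spatial layout. A standalone sketch (plain PyTorch, hypothetical sizes) checking that this permute/reshape round trip is lossless:

import torch

b, c, h, w = 2, 320, 16, 16
x = torch.randn(b, c, h, w)

# flatten: (b, c, h, w) -> (b, h*w, c), as in SpatialTransformer.forward above
tokens = x.permute(0, 2, 3, 1).reshape(b, h * w, c)

# unflatten: (b, h*w, c) -> (b, c, h, w)
restored = tokens.reshape(b, h, w, c).permute(0, 3, 1, 2)

assert torch.equal(x, restored)  # the round trip preserves every element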
@@ -192,30 +145,15 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size
 
-    def forward(self, hidden_states, context=None):
-        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
-        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
-        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
-        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
-        return hidden_states
+    def forward(self, x, context=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
 
 
 class CrossAttention(nn.Module):
-    r"""
-    A cross attention layer.
-
-    Parameters:
-        query_dim (:obj:`int`): The number of channels in the query.
-        context_dim (:obj:`int`, *optional*):
-            The number of channels in the context. If not given, defaults to `query_dim`.
-        heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
-        dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
-        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
-    """
-
-    def __init__(
-        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
-    ):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
         inner_dim = dim_head * heads
         context_dim = context_dim if context_dim is not None else query_dim
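The next hunk reverts CrossAttention's head-reshaping helpers to their in-place style without changing what they compute: (batch, seq, heads * dim_head) is folded into (batch * heads, seq, dim_head) and back. A standalone round-trip check (plain PyTorch, hypothetical sizes):

import torch

batch_size, seq_len, heads, dim_head = 2, 64, 8, 40
tensor = torch.randn(batch_size, seq_len, heads * dim_head)

# reshape_heads_to_batch_dim: (b, s, h*d) -> (b*h, s, d)
folded = tensor.reshape(batch_size, seq_len, heads, dim_head)
folded = folded.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim_head)

# reshape_batch_dim_to_heads: (b*h, s, d) -> (b, s, h*d)
restored = folded.reshape(batch_size, heads, seq_len, dim_head)
restored = restored.permute(0, 2, 1, 3).reshape(batch_size, seq_len, heads * dim_head)

assert torch.equal(tensor, restored)  # the two helpers are exact inverses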
@@ -236,104 +174,77 @@ def __init__(
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor2 = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor3
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
 
     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor2 = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor3
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
 
-    def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, dim = hidden_states.shape
+    def forward(self, x, context=None, mask=None):
+        batch_size, sequence_length, dim = x.shape
 
-        query = self.to_q(hidden_states)
-        context = context if context is not None else hidden_states
-        key = self.to_k(context)
-        value = self.to_v(context)
+        q = self.to_q(x)
+        context = context if context is not None else x
+        k = self.to_k(context)
+        v = self.to_v(context)
 
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        q = self.reshape_heads_to_batch_dim(q)
+        k = self.reshape_heads_to_batch_dim(k)
+        v = self.reshape_heads_to_batch_dim(v)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        hidden_states = self._attention(query, key, value, sequence_length, dim)
+        hidden_states = self._attention(q, k, v, sequence_length, dim)
 
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
-        # hidden_states = torch.zeros(
-        #     (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
-        # )
-        slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
-        # for i in range(hidden_states.shape[0] // slice_size):
-        #     start_idx = i * slice_size
-        #     end_idx = (i + 1) * slice_size
-        #     qslice = query[start_idx:end_idx]
-        qslice = query
-        #     kslice = key[start_idx:end_idx].transpose(1, 2)
-        kslice = key.transpose(1, 2)
-        attn_slice = torch.matmul(qslice, kslice) * self.scale
-        attn_slice = attn_slice.softmax(dim=-1)
-        #     vslice = value[start_idx:end_idx]
-        vslice = value
-        hidden_states = torch.matmul(attn_slice, vslice)
-
-
-        # hidden_states = torch.cat(attn_slices, dim=0)
-
+        hidden_states = torch.zeros(
+            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+        )
+        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+        for i in range(hidden_states.shape[0] // slice_size):
+            start_idx = i * slice_size
+            end_idx = (i + 1) * slice_size
+            attn_slice = (
+                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
+            )
+            attn_slice = attn_slice.softmax(dim=-1)
+            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+
+            hidden_states[start_idx:end_idx] = attn_slice
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
 
 class FeedForward(nn.Module):
-    r"""
-    A feed-forward layer.
-
-    Parameters:
-        dim (:obj:`int`): The number of channels in the input.
-        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
-        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
-        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
-        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
-    """
-
-    def __init__(
-        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
-    ):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
         project_in = GEGLU(dim, inner_dim)
 
         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
-    def forward(self, hidden_states):
-        return self.net(hidden_states)
+    def forward(self, x):
+        return self.net(x)
 
 
 # feedforward
 class GEGLU(nn.Module):
-    r"""
-    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
-
-    Parameters:
-        dim_in (:obj:`int`): The number of channels in the input.
-        dim_out (:obj:`int`): The number of channels in the output.
-    """
-
-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in, dim_out):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
-    def forward(self, hidden_states):
-        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * F.gelu(gate)
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
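The revert also restores the sliced `_attention`: scores are computed per slice along the (batch * heads) dimension and written into a preallocated buffer, which caps peak memory without changing the result. A standalone sketch (plain PyTorch, hypothetical sizes; scale taken as dim_head ** -0.5, the usual convention) showing that the sliced einsum loop matches unsliced attention:

import torch

batch_heads, seq_len, head_dim = 8, 64, 40   # hypothetical (batch * heads, tokens, channels per head)
scale = head_dim ** -0.5
query = torch.randn(batch_heads, seq_len, head_dim)
key = torch.randn(batch_heads, seq_len, head_dim)
value = torch.randn(batch_heads, seq_len, head_dim)

# Unsliced reference: full attention in one shot.
scores = torch.einsum("b i d, b j d -> b i j", query, key) * scale
full = torch.einsum("b i j, b j d -> b i d", scores.softmax(dim=-1), value)

# Sliced version, mirroring the loop restored in _attention above.
slice_size = 2
sliced = torch.zeros_like(full)
for i in range(batch_heads // slice_size):
    start_idx, end_idx = i * slice_size, (i + 1) * slice_size
    attn_slice = torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * scale
    attn_slice = attn_slice.softmax(dim=-1)
    sliced[start_idx:end_idx] = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])

assert torch.allclose(full, sliced, atol=1e-6)  # same result, smaller peak memory per step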
