Commit 425192f

Make sure VAE attention works with Torch 2_0 (#3200)
* Make sure attention works with Torch 2_0
* make style
* Fix more
1 parent 9965cb5 commit 425192f
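
The commit routes the VAE's AttentionBlock through torch.nn.functional.scaled_dot_product_attention whenever PyTorch 2.0 provides it and xformers has not been explicitly enabled. A minimal sketch of the runtime availability check the diff relies on:

import torch.nn.functional as F

# PyTorch 2.0 introduced F.scaled_dot_product_attention; earlier versions
# lack the attribute, so hasattr() is enough to pick the fast path at runtime.
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
print("SDPA available:", use_torch_2_0_attn)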

File tree

src/diffusers/models/attention.py
tests/models/test_models_vae.py

2 files changed: 63 additions, 12 deletions


src/diffusers/models/attention.py

Lines changed: 29 additions & 12 deletions
@@ -60,7 +60,6 @@ def __init__(
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
         self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
@@ -74,18 +73,25 @@ def __init__(
         self._use_memory_efficient_attention_xformers = False
         self._attention_op = None
 
-    def reshape_heads_to_batch_dim(self, tensor):
+    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+        if merge_head_and_batch:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
         return tensor
 
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+
+        if unmerge_head_and_batch:
+            batch_size, seq_len, dim = tensor.shape
+            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        else:
+            batch_size, _, seq_len, dim = tensor.shape
+
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
         return tensor
 
     def set_use_memory_efficient_attention_xformers(
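
The two reshape helpers now support both layouts: the classic bmm/xformers path expects (batch * heads, seq_len, head_dim), while torch 2.0's scaled_dot_product_attention expects (batch, heads, seq_len, head_dim). A standalone sketch of the two layouts (sizes are illustrative, not from the diff):

import torch

batch, heads, seq_len, head_dim = 2, 4, 16, 8
x = torch.randn(batch, seq_len, heads * head_dim)

# merged layout, used by torch.bmm and xformers: (batch * heads, seq_len, head_dim)
merged = x.reshape(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)
merged = merged.reshape(batch * heads, seq_len, head_dim)

# unmerged layout, used by F.scaled_dot_product_attention: (batch, heads, seq_len, head_dim)
unmerged = x.reshape(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)

print(merged.shape)    # torch.Size([8, 16, 8])
print(unmerged.shape)  # torch.Size([2, 4, 16, 8])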
@@ -134,14 +140,25 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        use_torch_2_0_attn = (
+            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
+        )
+
+        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
             hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
+            )
+            hidden_states = hidden_states.to(query_proj.dtype)
+        elif use_torch_2_0_attn:
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
             )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
@@ -162,7 +179,7 @@ def forward(self, hidden_states):
         hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
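
With these changes, forward dispatches three ways: xformers when explicitly enabled, F.scaled_dot_product_attention when PyTorch exposes it, and the original bmm-plus-softmax fallback otherwise. A self-contained sketch of the torch 2.0 branch next to that fallback (shapes are illustrative; the default VAE block uses a single head):

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); heads=1 matches the default VAE AttentionBlock
q = torch.randn(2, 1, 64, 32)
k = torch.randn(2, 1, 64, 32)
v = torch.randn(2, 1, 64, 32)

if hasattr(F, "scaled_dot_product_attention"):
    # the diff's TODO notes an explicit scale argument only arrives in torch 2.1;
    # until then SDPA applies its default 1/sqrt(head_dim) scaling
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
else:
    # pre-2.0 fallback: plain softmax attention
    attn = (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
    out = attn.softmax(dim=-1) @ v

print(out.shape)  # torch.Size([2, 1, 64, 32])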

tests/models/test_models_vae.py

Lines changed: 34 additions & 0 deletions
@@ -319,6 +319,40 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
 
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
 
+    @parameterized.expand([13, 16, 27])
+    @require_torch_gpu
+    def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
+        model = self.get_sd_vae_model(fp16=True)
+        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
+
+        with torch.no_grad():
+            sample = model.decode(encoding).sample
+
+        model.enable_xformers_memory_efficient_attention()
+        with torch.no_grad():
+            sample_2 = model.decode(encoding).sample
+
+        assert list(sample.shape) == [3, 3, 512, 512]
+
+        assert torch_all_close(sample, sample_2, atol=1e-1)
+
+    @parameterized.expand([13, 16, 37])
+    @require_torch_gpu
+    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
+        model = self.get_sd_vae_model()
+        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
+
+        with torch.no_grad():
+            sample = model.decode(encoding).sample
+
+        model.enable_xformers_memory_efficient_attention()
+        with torch.no_grad():
+            sample_2 = model.decode(encoding).sample
+
+        assert list(sample.shape) == [3, 3, 512, 512]
+
+        assert torch_all_close(sample, sample_2, atol=1e-2)
+
     @parameterized.expand(
         [
             # fmt: off
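
The two new tests decode the same latents once with the default attention (the torch 2.0 SDPA path on PyTorch 2.x) and once with xformers enabled, then require the outputs to match within a loose tolerance (1e-1 for fp16, 1e-2 for fp32). A hedged sketch of the same comparison outside the test harness; the checkpoint id is an assumption, not taken from the diff:

import torch
from diffusers import AutoencoderKL

device = "cuda"  # the tests are gated on @require_torch_gpu
vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae"  # assumed checkpoint
).to(device)

latents = torch.randn(3, 4, 64, 64, device=device)

with torch.no_grad():
    sample = vae.decode(latents).sample  # default path (SDPA on torch 2.x)

vae.enable_xformers_memory_efficient_attention()
with torch.no_grad():
    sample_2 = vae.decode(latents).sample  # xformers path

assert sample.shape == (3, 3, 512, 512)
assert torch.allclose(sample, sample_2, atol=1e-2)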
