
Commit f73ca90

[Tests] Test attention.py (#368)
* add test for AttentionBlock, SpatialTransformer
* add context_dim, handle device
* removed dropout test
* fixes, add dropout test
1 parent 37c9d78 commit f73ca90

File tree

1 file changed: +99 −0

tests/test_layers_utils.py

Lines changed: 99 additions & 0 deletions
@@ -19,8 +19,10 @@
 import numpy as np
 import torch
 
+from diffusers.models.attention import AttentionBlock, SpatialTransformer
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, Upsample2D
+from diffusers.testing_utils import torch_device
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -216,3 +218,100 @@ def test_downsample_with_conv_out_dim(self):
         output_slice = downsampled[0, -1, -3:, -3:]
         expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class AttentionBlockTests(unittest.TestCase):
+    def test_attention_block_default(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        attentionBlock = AttentionBlock(
+            channels=32,
+            num_head_channels=1,
+            rescale_output_factor=1.0,
+            eps=1e-6,
+            num_groups=32,
+        ).to(torch_device)
+        with torch.no_grad():
+            attention_scores = attentionBlock(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class SpatialTransformerTests(unittest.TestCase):
+    def test_spatial_transformer_default(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = SpatialTransformer(
+            in_channels=32,
+            n_heads=1,
+            d_head=32,
+            dropout=0.0,
+            context_dim=None,
+        ).to(torch_device)
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_context_dim(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 64, 64, 64).to(torch_device)
+        spatial_transformer_block = SpatialTransformer(
+            in_channels=64,
+            n_heads=2,
+            d_head=32,
+            dropout=0.0,
+            context_dim=64,
+        ).to(torch_device)
+        with torch.no_grad():
+            context = torch.randn(1, 4, 64).to(torch_device)
+            attention_scores = spatial_transformer_block(sample, context)
+
+        assert attention_scores.shape == (1, 64, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_dropout(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = (
+            SpatialTransformer(
+                in_channels=32,
+                n_heads=2,
+                d_head=16,
+                dropout=0.3,
+                context_dim=None,
+            )
+            .to(torch_device)
+            .eval()
+        )
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
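The new tests above can be run in isolation with pytest. A minimal sketch of a programmatic invocation, assuming pytest is installed and the command is run from the repository root (the -k expression is an assumption about test selection, not part of the commit):

# Sketch: run only the attention-related test classes added in this commit.
import pytest

# -k selects tests whose names match the expression; -v prints one line per test.
pytest.main(["tests/test_layers_utils.py", "-k", "AttentionBlockTests or SpatialTransformerTests", "-v"])

The equivalent command-line call would pass the same arguments directly to pytest.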

0 commit comments
