diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
index a94ecd58d538..ec0434e24224 100755
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -19,8 +19,10 @@
 import numpy as np
 import torch
 
+from diffusers.models.attention import AttentionBlock, SpatialTransformer
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, Upsample2D
+from diffusers.testing_utils import torch_device
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -216,3 +218,100 @@ def test_downsample_with_conv_out_dim(self):
         output_slice = downsampled[0, -1, -3:, -3:]
         expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class AttentionBlockTests(unittest.TestCase):
+    def test_attention_block_default(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        attention_block = AttentionBlock(
+            channels=32,
+            num_head_channels=1,
+            rescale_output_factor=1.0,
+            eps=1e-6,
+            num_groups=32,
+        ).to(torch_device)
+        with torch.no_grad():
+            attention_scores = attention_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class SpatialTransformerTests(unittest.TestCase):
+    def test_spatial_transformer_default(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = SpatialTransformer(
+            in_channels=32,
+            n_heads=1,
+            d_head=32,
+            dropout=0.0,
+            context_dim=None,
+        ).to(torch_device)
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_context_dim(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 64, 64, 64).to(torch_device)
+        spatial_transformer_block = SpatialTransformer(
+            in_channels=64,
+            n_heads=2,
+            d_head=32,
+            dropout=0.0,
+            context_dim=64,
+        ).to(torch_device)
+        with torch.no_grad():
+            context = torch.randn(1, 4, 64).to(torch_device)
+            attention_scores = spatial_transformer_block(sample, context)
+
+        assert attention_scores.shape == (1, 64, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_dropout(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = (
+            SpatialTransformer(
+                in_channels=32,
+                n_heads=2,
+                d_head=16,
+                dropout=0.3,
+                context_dim=None,
+            )
+            .to(torch_device)
+            .eval()
+        )
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
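Note: the tests above pin tensors and modules to `torch_device`, imported from `diffusers.testing_utils`, so that CPU and CUDA runs exercise the same code path. As a minimal sketch of the pattern these tests assume (the exact helper lives in `diffusers.testing_utils` and may differ):

    import torch

    # Prefer CUDA when available, otherwise fall back to CPU, so the
    # device-pinned tensors in the tests above all land on one device.
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

The new test classes can then be run in isolation with, for example, `python -m pytest tests/test_layers_utils.py -k "AttentionBlock or SpatialTransformer"`. The hard-coded `expected_slice` values were recorded under the fixed seed set at the top of each test, which is why the `atol=1e-3` comparisons are deterministic.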