 import numpy as np
 import torch
 
+from diffusers.models.attention import AttentionBlock, SpatialTransformer
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, Upsample2D
 
@@ -216,3 +217,45 @@ def test_downsample_with_conv_out_dim(self):
         output_slice = downsampled[0, -1, -3:, -3:]
         expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class AttentionBlockTests(unittest.TestCase):
+    def test_attention_block_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        attentionBlock = AttentionBlock(
+            channels=32,
+            num_head_channels=1,
+            rescale_output_factor=1.0,
+            eps=1e-6,
+            num_groups=32,
+        )
+        with torch.no_grad():
+            attention_scores = attentionBlock(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
+
+
+class SpatialTransformerTests(unittest.TestCase):
+    def test_spatial_transformer_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64)
+        spatialTransformerBlock = SpatialTransformer(
+            in_channels=32,
+            n_heads=1,
+            d_head=32,
+            dropout=0.0,
+            context_dim=None,
+        )
+        with torch.no_grad():
+            attention_scores = spatialTransformerBlock(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
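To try the two new test classes in isolation, the snippet below loads them with the standard unittest loader. It is a minimal sketch only: the module path tests.test_layers_utils is an assumption (the diff does not show the file name), and it presumes diffusers and torch are installed in a version that still exposes AttentionBlock and SpatialTransformer under diffusers.models.attention.

import unittest

# Module path is an assumption; adjust to wherever this test file lives in your checkout.
from tests.test_layers_utils import AttentionBlockTests, SpatialTransformerTests

# Build a suite containing only the two new test classes and run it verbosely.
loader = unittest.defaultTestLoader
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromTestCase(AttentionBlockTests))
suite.addTests(loader.loadTestsFromTestCase(SpatialTransformerTests))
unittest.TextTestRunner(verbosity=2).run(suite)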