@@ -19,8 +19,10 @@
 import numpy as np
 import torch
 
+from diffusers.models.attention import AttentionBlock, SpatialTransformer
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, Upsample2D
+from diffusers.testing_utils import torch_device
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -216,3 +218,104 @@ def test_downsample_with_conv_out_dim(self):
         output_slice = downsampled[0, -1, -3:, -3:]
         expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class AttentionBlockTests(unittest.TestCase):
+    def test_attention_block_default(self):
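+        # fixed seeds make the random weights and input reproducible, so the hard-coded slice below stays valid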
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        attention_block = AttentionBlock(
+            channels=32,
+            num_head_channels=1,
+            rescale_output_factor=1.0,
+            eps=1e-6,
+            num_groups=32,
+        ).to(torch_device)
+        with torch.no_grad():
+            attention_scores = attention_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class SpatialTransformerTests(unittest.TestCase):
+    def test_spatial_transformer_default(self):
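+        # context_dim=None: the block is run without a conditioning context (self-attention only)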
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = SpatialTransformer(
+            in_channels=32,
+            n_heads=1,
+            d_head=32,
+            dropout=0.0,
+            context_dim=None,
+        ).to(torch_device)
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_context_dim(self):
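+        # a (batch, seq_len, context_dim) tensor is passed as context to exercise the cross-attention path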
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 64, 64, 64).to(torch_device)
+        spatial_transformer_block = SpatialTransformer(
+            in_channels=64,
+            n_heads=2,
+            d_head=32,
+            dropout=0.0,
+            context_dim=64,
+        ).to(torch_device)
+        with torch.no_grad():
+            context = torch.randn(1, 4, 64).to(torch_device)
+            attention_scores = spatial_transformer_block(sample, context)
+
+        assert attention_scores.shape == (1, 64, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_dropout(self):
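+        # dropout=0.3, but .eval() below disables it, so the output is still deterministic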
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = (
+            SpatialTransformer(
+                in_channels=32,
+                n_heads=2,
+                d_head=16,
+                dropout=0.3,
+                context_dim=None,
+            )
+            .to(torch_device)
+            .eval()
+        )
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)