
Commit 20e198a

fixes, add dropout test
1 parent db75e7d commit 20e198a

File tree

1 file changed (+32, -6 lines)


tests/test_layers_utils.py

Lines changed: 32 additions & 6 deletions
@@ -251,15 +251,15 @@ def test_spatial_transformer_default(self):
             torch.cuda.manual_seed_all(0)

         sample = torch.randn(1, 32, 64, 64).to(torch_device)
-        spatialTransformerBlock = SpatialTransformer(
+        spatial_transformer_block = SpatialTransformer(
             in_channels=32,
             n_heads=1,
             d_head=32,
             dropout=0.0,
             context_dim=None,
         ).to(torch_device)
         with torch.no_grad():
-            attention_scores = spatialTransformerBlock(sample)
+            attention_scores = spatial_transformer_block(sample)

         assert attention_scores.shape == (1, 32, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
@@ -272,20 +272,46 @@ def test_spatial_transformer_context_dim(self):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        torch.manual_seed(0)
         sample = torch.randn(1, 64, 64, 64).to(torch_device)
-        spatialTransformerBlock = SpatialTransformer(
+        spatial_transformer_block = SpatialTransformer(
             in_channels=64,
             n_heads=2,
             d_head=32,
             dropout=0.0,
             context_dim=64,
         ).to(torch_device)
         with torch.no_grad():
-            attention_scores = spatialTransformerBlock(sample)
+            context = torch.randn(1, 4, 64).to(torch_device)
+            attention_scores = spatial_transformer_block(sample, context)

         assert attention_scores.shape == (1, 64, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]

-        expected_slice = torch.tensor([-0.0278, -0.7288, -2.2825, -2.0128, 1.4513, 0.2600, -0.2489, -1.4279, 0.1277])
+        expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_dropout(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = (
+            SpatialTransformer(
+                in_channels=32,
+                n_heads=2,
+                d_head=16,
+                dropout=0.3,
+                context_dim=None,
+            )
+            .to(torch_device)
+            .eval()
+        )
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
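
The new test_spatial_transformer_dropout case builds the block with dropout=0.3 but calls .eval() before the forward pass, so the dropout layers act as identities and the expected output slice stays deterministic. A minimal sketch of that train/eval behaviour using a plain torch.nn.Dropout (illustrative only, not code from this repository):

import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.3)
x = torch.ones(4)

# Train mode: elements are zeroed with probability 0.3 and survivors are scaled by 1 / (1 - 0.3).
dropout.train()
print(dropout(x))

# Eval mode: dropout is a no-op, so the output equals the input and a fixed expected slice can be asserted.
dropout.eval()
assert torch.equal(dropout(x), x)

Assuming the usual pytest layout, the new case can be run on its own with: pytest tests/test_layers_utils.py -k dropout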
