Skip to content

Commit 9b2a3ca

Browse files
committed
removed dropout test
1 parent 129333d commit 9b2a3ca

File tree

1 file changed

+0
-22
lines changed

1 file changed

+0
-22
lines changed

tests/test_layers_utils.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -289,25 +289,3 @@ def test_spatial_transformer_context_dim(self):
289289

290290
expected_slice = torch.tensor([-0.0278, -0.7288, -2.2825, -2.0128, 1.4513, 0.2600, -0.2489, -1.4279, 0.1277])
291291
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
292-
293-
def test_spatial_transformer_dropout(self):
294-
torch.manual_seed(0)
295-
if torch.cuda.is_available():
296-
torch.cuda.manual_seed_all(0)
297-
298-
sample = torch.randn(1, 32, 64, 64).to(torch_device)
299-
spatialTransformerBlock = SpatialTransformer(
300-
in_channels=32,
301-
n_heads=2,
302-
d_head=16,
303-
dropout=0.3,
304-
context_dim=None,
305-
).to(torch_device)
306-
with torch.no_grad():
307-
attention_scores = spatialTransformerBlock(sample)
308-
309-
assert attention_scores.shape == (1, 32, 64, 64)
310-
output_slice = attention_scores[0, -1, -3:, -3:]
311-
312-
expected_slice = torch.tensor([-1.4387, 0.0335, -0.9627, -1.4815, 0.6288, -1.0577, -2.1272, 0.8841, -1.0216])
313-
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

0 commit comments

Comments
 (0)