@@ -251,15 +251,15 @@ def test_spatial_transformer_default(self):
             torch.cuda.manual_seed_all(0)
 
         sample = torch.randn(1, 32, 64, 64).to(torch_device)
-        spatialTransformerBlock = SpatialTransformer(
+        spatial_transformer_block = SpatialTransformer(
             in_channels=32,
             n_heads=1,
             d_head=32,
             dropout=0.0,
             context_dim=None,
         ).to(torch_device)
         with torch.no_grad():
-            attention_scores = spatialTransformerBlock(sample)
+            attention_scores = spatial_transformer_block(sample)
 
         assert attention_scores.shape == (1, 32, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
@@ -272,20 +272,46 @@ def test_spatial_transformer_context_dim(self):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
-        torch.manual_seed(0)
         sample = torch.randn(1, 64, 64, 64).to(torch_device)
-        spatialTransformerBlock = SpatialTransformer(
+        spatial_transformer_block = SpatialTransformer(
             in_channels=64,
             n_heads=2,
             d_head=32,
             dropout=0.0,
             context_dim=64,
         ).to(torch_device)
         with torch.no_grad():
-            attention_scores = spatialTransformerBlock(sample)
+            context = torch.randn(1, 4, 64).to(torch_device)
+            attention_scores = spatial_transformer_block(sample, context)
 
         assert attention_scores.shape == (1, 64, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([-0.0278, -0.7288, -2.2825, -2.0128, 1.4513, 0.2600, -0.2489, -1.4279, 0.1277])
+        expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_dropout(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = (
+            SpatialTransformer(
+                in_channels=32,
+                n_heads=2,
+                d_head=16,
+                dropout=0.3,
+                context_dim=None,
+            )
+            .to(torch_device)
+            .eval()
+        )
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
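For reference, a minimal standalone sketch of the cross-attention path this commit starts exercising (passing a context tensor alongside the sample). The import path and call signature are assumptions inferred from the test code above, not confirmed API documentation:

    import torch
    from diffusers.models.attention import SpatialTransformer  # import path is an assumption

    # context_dim enables cross-attention; with context_dim=None the block self-attends only.
    block = SpatialTransformer(in_channels=64, n_heads=2, d_head=32, dropout=0.0, context_dim=64).eval()
    sample = torch.randn(1, 64, 64, 64)  # (batch, channels, height, width)
    context = torch.randn(1, 4, 64)      # (batch, sequence_length, context_dim)
    with torch.no_grad():
        out = block(sample, context)     # spatial positions attend to the context sequence
    assert out.shape == sample.shape     # output keeps the input's spatial shape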