@@ -240,7 +240,9 @@ def test_attention_block_default(self):
240240 assert attention_scores .shape == (1 , 32 , 64 , 64 )
241241 output_slice = attention_scores [0 , - 1 , - 3 :, - 3 :]
242242
243- expected_slice = torch .tensor ([- 1.4975 , - 0.0038 , - 0.7847 , - 1.4567 , 1.1220 , - 0.8962 , - 1.7394 , 1.1319 , - 0.5427 ])
243+ expected_slice = torch .tensor (
244+ [- 1.4975 , - 0.0038 , - 0.7847 , - 1.4567 , 1.1220 , - 0.8962 , - 1.7394 , 1.1319 , - 0.5427 ], device = torch_device
245+ )
244246 assert torch .allclose (output_slice .flatten (), expected_slice , atol = 1e-3 )
245247
246248
@@ -264,7 +266,9 @@ def test_spatial_transformer_default(self):
264266 assert attention_scores .shape == (1 , 32 , 64 , 64 )
265267 output_slice = attention_scores [0 , - 1 , - 3 :, - 3 :]
266268
267- expected_slice = torch .tensor ([- 1.2447 , - 0.0137 , - 0.9559 , - 1.5223 , 0.6991 , - 1.0126 , - 2.0974 , 0.8921 , - 1.0201 ])
269+ expected_slice = torch .tensor (
270+ [- 1.2447 , - 0.0137 , - 0.9559 , - 1.5223 , 0.6991 , - 1.0126 , - 2.0974 , 0.8921 , - 1.0201 ], device = torch_device
271+ )
268272 assert torch .allclose (output_slice .flatten (), expected_slice , atol = 1e-3 )
269273
270274 def test_spatial_transformer_context_dim (self ):
@@ -287,7 +291,9 @@ def test_spatial_transformer_context_dim(self):
287291 assert attention_scores .shape == (1 , 64 , 64 , 64 )
288292 output_slice = attention_scores [0 , - 1 , - 3 :, - 3 :]
289293
290- expected_slice = torch .tensor ([- 0.2555 , - 0.8877 , - 2.4739 , - 2.2251 , 1.2714 , 0.0807 , - 0.4161 , - 1.6408 , - 0.0471 ])
294+ expected_slice = torch .tensor (
295+ [- 0.2555 , - 0.8877 , - 2.4739 , - 2.2251 , 1.2714 , 0.0807 , - 0.4161 , - 1.6408 , - 0.0471 ], device = torch_device
296+ )
291297 assert torch .allclose (output_slice .flatten (), expected_slice , atol = 1e-3 )
292298
293299 def test_spatial_transformer_dropout (self ):
@@ -313,5 +319,7 @@ def test_spatial_transformer_dropout(self):
313319 assert attention_scores .shape == (1 , 32 , 64 , 64 )
314320 output_slice = attention_scores [0 , - 1 , - 3 :, - 3 :]
315321
316- expected_slice = torch .tensor ([- 1.2448 , - 0.0190 , - 0.9471 , - 1.5140 , 0.7069 , - 1.0144 , - 2.1077 , 0.9099 , - 1.0091 ])
322+ expected_slice = torch .tensor (
323+ [- 1.2448 , - 0.0190 , - 0.9471 , - 1.5140 , 0.7069 , - 1.0144 , - 2.1077 , 0.9099 , - 1.0091 ], device = torch_device
324+ )
317325 assert torch .allclose (output_slice .flatten (), expected_slice , atol = 1e-3 )
0 commit comments