@@ -411,7 +411,9 @@ def test_spatial_transformer_cross_attention_dim(self):
        assert attention_scores.shape == (1, 64, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

-        expected_slice = torch.tensor([0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598])
+        expected_slice = torch.tensor(
+            [0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device
+        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_spatial_transformer_timestep(self):
@@ -442,9 +444,11 @@ def test_spatial_transformer_timestep(self):
        output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
        output_slice_2 = attention_scores_2[0, -1, -3:, -3:]

-        expected_slice = torch.tensor([-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703])
+        expected_slice = torch.tensor(
+            [-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device
+        )
        expected_slice_2 = torch.tensor(
-            [-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348]
+            [-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348], device=torch_device
        )

        assert torch.allclose(output_slice_1.flatten(), expected_slice, atol=1e-3)
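The diff boils down to allocating the reference tensors on the same device as the model output, so that `torch.allclose` does not hit a cross-device comparison error when the tests run on CUDA or MPS. A minimal sketch of that pattern, where the `torch_device` resolution and the tensor values are illustrative placeholders rather than code taken from the test suite:

```python
import torch

# Illustrative stand-in for the test suite's `torch_device` helper:
# prefer an accelerator if one is present, otherwise fall back to CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Dummy stand-in for the transformer block's output, living on torch_device.
attention_scores = torch.randn(1, 64, 64, 64, device=torch_device)
output_slice = attention_scores[0, -1, -3:, -3:]

# Building the expected values on the same device keeps torch.allclose from
# raising a cross-device RuntimeError when torch_device is not "cpu".
expected_slice = torch.zeros(9, device=torch_device)  # placeholder values

print(torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3))
```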