Skip to content

Commit 0c72006

Browse files
fix slow tsets (#3066)
* fix slow tsets * make style
1 parent a89a14f commit 0c72006

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tests/test_layers_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,9 @@ def test_spatial_transformer_cross_attention_dim(self):
411411

412412
assert attention_scores.shape == (1, 64, 64, 64)
413413
output_slice = attention_scores[0, -1, -3:, -3:]
414-
expected_slice = torch.tensor([0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598])
414+
expected_slice = torch.tensor(
415+
[0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device
416+
)
415417
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
416418

417419
def test_spatial_transformer_timestep(self):
@@ -442,9 +444,11 @@ def test_spatial_transformer_timestep(self):
442444
output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
443445
output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
444446

445-
expected_slice = torch.tensor([-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703])
447+
expected_slice = torch.tensor(
448+
[-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device
449+
)
446450
expected_slice_2 = torch.tensor(
447-
[-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348]
451+
[-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348], device=torch_device
448452
)
449453

450454
assert torch.allclose(output_slice_1.flatten(), expected_slice, atol=1e-3)

0 commit comments

Comments
 (0)