From 2e8fb0f188318acd66e1b734dd6d30be6decec20 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 12 Apr 2023 10:10:49 +0200
Subject: [PATCH 1/2] fix slow tests

---
 tests/test_layers_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
index 1f6e445f9d61..6fdce93599e4 100644
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -411,7 +411,7 @@ def test_spatial_transformer_cross_attention_dim(self):
         assert attention_scores.shape == (1, 64, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598])
+        expected_slice = torch.tensor([0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device)
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
     def test_spatial_transformer_timestep(self):
@@ -442,9 +442,10 @@ def test_spatial_transformer_timestep(self):
         output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
         output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703])
+        expected_slice = torch.tensor([-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device)
         expected_slice_2 = torch.tensor(
-            [-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348]
+            [-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348],
+            device=torch_device
         )
 
         assert torch.allclose(output_slice_1.flatten(), expected_slice, atol=1e-3)

From 8380e24bcd302ef049ee5f6734d899d45e92ddcf Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 12 Apr 2023 10:11:51 +0200
Subject: [PATCH 2/2] make style

---
 tests/test_layers_utils.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
index 6fdce93599e4..db0d6c78d902 100644
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -411,7 +411,9 @@ def test_spatial_transformer_cross_attention_dim(self):
         assert attention_scores.shape == (1, 64, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device)
+        expected_slice = torch.tensor(
+            [0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device
+        )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
     def test_spatial_transformer_timestep(self):
@@ -442,10 +444,11 @@ def test_spatial_transformer_timestep(self):
         output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
         output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device)
+        expected_slice = torch.tensor(
+            [-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device
+        )
         expected_slice_2 = torch.tensor(
-            [-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348],
-            device=torch_device
+            [-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348], device=torch_device
         )
 
         assert torch.allclose(output_slice_1.flatten(), expected_slice, atol=1e-3)
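
For context, a minimal standalone sketch (not part of the patches above) of the pattern they apply. A likely reason the GPU "slow" runs needed this change is that torch.allclose cannot compare tensors on different devices, so an expected tensor created on the CPU cannot be checked against model output living on CUDA. The torch_device definition below is an assumption standing in for the diffusers test helper of the same name.

import torch

# Stand-in for diffusers' test helper: "cuda" when available, otherwise "cpu".
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Stands in for the attention output slice produced by a model on torch_device.
output_slice = torch.randn(3, 3, device=torch_device)

# Old pattern: expected values live on the CPU. On a CUDA runner the comparison
# raises a RuntimeError about tensors being on different devices.
expected_cpu = torch.tensor([0.0] * 9)
try:
    torch.allclose(output_slice.flatten(), expected_cpu, atol=1e-3)
except RuntimeError as err:
    print(f"device mismatch: {err}")

# Patched pattern: create the expected values directly on torch_device,
# so the same assertion works on both CPU and GPU runners.
expected_on_device = torch.tensor([0.0] * 9, device=torch_device)
print(torch.allclose(output_slice.flatten(), expected_on_device, atol=1e-3))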