From 002b10913c24be2a81c867ce53932b909da0da5a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 21 Mar 2023 11:05:18 +0000 Subject: [PATCH] Fix mps tests on torch 2.0 --- tests/test_unet_2d_blocks.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/test_unet_2d_blocks.py b/tests/test_unet_2d_blocks.py index ea3390e8e6a9..e560240422ac 100644 --- a/tests/test_unet_2d_blocks.py +++ b/tests/test_unet_2d_blocks.py @@ -255,10 +255,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_output(self): - if torch_device == "mps": - expected_slice = [0.4327, 0.5538, 0.3919, 0.5682, 0.2704, 0.1573, -0.8768, -0.4615, -0.4146] - else: - expected_slice = [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402] + expected_slice = [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402] super().test_output(expected_slice) @@ -336,8 +333,5 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_output(self): - if torch_device == "mps": - expected_slice = [-0.3669, -0.3387, 0.1029, -0.6564, 0.2728, -0.3233, 0.5977, -0.1784, 0.5482] - else: - expected_slice = [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568] + expected_slice = [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568] super().test_output(expected_slice)