diff --git a/tests/test_unet_2d_blocks.py b/tests/test_unet_2d_blocks.py index ea3390e8e6a9..fd96d6ffdb9d 100644 --- a/tests/test_unet_2d_blocks.py +++ b/tests/test_unet_2d_blocks.py @@ -254,11 +254,9 @@ def prepare_init_args_and_inputs_for_common(self): init_dict["cross_attention_dim"] = 32 return init_dict, inputs_dict + @unittest.skipIf(torch_device == "mps", "MPS result is not consistent") def test_output(self): - if torch_device == "mps": - expected_slice = [0.4327, 0.5538, 0.3919, 0.5682, 0.2704, 0.1573, -0.8768, -0.4615, -0.4146] - else: - expected_slice = [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402] + expected_slice = [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402] super().test_output(expected_slice) @@ -335,9 +333,7 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + @unittest.skipIf(torch_device == "mps", "MPS result is not consistent") def test_output(self): - if torch_device == "mps": - expected_slice = [-0.3669, -0.3387, 0.1029, -0.6564, 0.2728, -0.3233, 0.5977, -0.1784, 0.5482] - else: - expected_slice = [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568] + expected_slice = [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568] super().test_output(expected_slice)