Skip to content

Commit 1fcf279

Browse files
authored
Fix mps tests on torch 2.0 (#2766)
1 parent 58bcf46 commit 1fcf279

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

tests/test_unet_2d_blocks.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,7 @@ def prepare_init_args_and_inputs_for_common(self):
255255
return init_dict, inputs_dict
256256

257257
def test_output(self):
258-
if torch_device == "mps":
259-
expected_slice = [0.4327, 0.5538, 0.3919, 0.5682, 0.2704, 0.1573, -0.8768, -0.4615, -0.4146]
260-
else:
261-
expected_slice = [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402]
258+
expected_slice = [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402]
262259
super().test_output(expected_slice)
263260

264261

@@ -336,8 +333,5 @@ def prepare_init_args_and_inputs_for_common(self):
336333
return init_dict, inputs_dict
337334

338335
def test_output(self):
339-
if torch_device == "mps":
340-
expected_slice = [-0.3669, -0.3387, 0.1029, -0.6564, 0.2728, -0.3233, 0.5977, -0.1784, 0.5482]
341-
else:
342-
expected_slice = [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568]
336+
expected_slice = [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568]
343337
super().test_output(expected_slice)

0 commit comments

Comments
 (0)