@@ -255,10 +255,7 @@ def prepare_init_args_and_inputs_for_common(self):
255
255
return init_dict , inputs_dict
256
256
257
257
def test_output (self ):
258
- if torch_device == "mps" :
259
- expected_slice = [0.4327 , 0.5538 , 0.3919 , 0.5682 , 0.2704 , 0.1573 , - 0.8768 , - 0.4615 , - 0.4146 ]
260
- else :
261
- expected_slice = [0.2645 , 0.1480 , 0.0909 , 0.8044 , - 0.9758 , - 0.9083 , 0.0994 , - 1.1453 , - 0.7402 ]
258
+ expected_slice = [0.2645 , 0.1480 , 0.0909 , 0.8044 , - 0.9758 , - 0.9083 , 0.0994 , - 1.1453 , - 0.7402 ]
262
259
super ().test_output (expected_slice )
263
260
264
261
@@ -336,8 +333,5 @@ def prepare_init_args_and_inputs_for_common(self):
336
333
return init_dict , inputs_dict
337
334
338
335
def test_output (self ):
339
- if torch_device == "mps" :
340
- expected_slice = [- 0.3669 , - 0.3387 , 0.1029 , - 0.6564 , 0.2728 , - 0.3233 , 0.5977 , - 0.1784 , 0.5482 ]
341
- else :
342
- expected_slice = [0.6738 , 0.4491 , 0.1055 , 1.0710 , 0.7316 , 0.3339 , 0.3352 , 0.1023 , 0.3568 ]
336
+ expected_slice = [0.6738 , 0.4491 , 0.1055 , 1.0710 , 0.7316 , 0.3339 , 0.3352 , 0.1023 , 0.3568 ]
343
337
super ().test_output (expected_slice )
0 commit comments