@@ -1201,11 +1201,30 @@ def test_forward_scriptability(self):
1201
1201
torch .jit .script (ops .DeformConv2d (in_channels = 8 , out_channels = 8 , kernel_size = 3 ))
1202
1202
1203
1203
1204
# NS: Remove me once backward is implemented for MPS
def xfail_if_mps(x):
    """Extend *x*'s ``device`` parametrization with an xfail-ing "mps" entry.

    Rewrites the test class/function's ``pytestmark`` list: the
    ``parametrize("device", ...)`` mark is replaced by one covering
    ``cpu_and_cuda()`` plus an "mps" param marked ``needs_mps`` + ``xfail``.
    Every other mark is passed through unchanged.  Returns *x* so it can be
    used as a decorator.
    """
    mps_param = pytest.param("mps", marks=(pytest.mark.needs_mps, pytest.mark.xfail))

    def _rewrite(mark):
        # Only the device parametrization gets replaced; anything else is kept.
        is_device_parametrize = (
            isinstance(mark, pytest.Mark)
            and mark.name == "parametrize"
            and mark.args[0] == "device"
        )
        if is_device_parametrize:
            return pytest.mark.parametrize("device", cpu_and_cuda() + (mps_param,))
        return mark

    # Assign via __dict__ to set the attribute on x itself (not a base class).
    x.__dict__["pytestmark"] = [_rewrite(mark) for mark in x.pytestmark]
    return x
1204
1219
optests .generate_opcheck_tests (
1205
1220
testcase = TestDeformConv ,
1206
1221
namespaces = ["torchvision" ],
1207
1222
failures_dict_path = os .path .join (os .path .dirname (__file__ ), "optests_failures_dict.json" ),
1208
- additional_decorators = [],
1223
+ # Skip tests due to unimplemented backward
1224
+ additional_decorators = {
1225
+ "test_aot_dispatch_dynamic__test_forward" : [xfail_if_mps ],
1226
+ "test_autograd_registration__test_forward" : [xfail_if_mps ],
1227
+ },
1209
1228
test_utils = OPTESTS ,
1210
1229
)
1211
1230
0 commit comments