We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8fdbf2f commit d9c7715Copy full SHA for d9c7715
utils/__init__.py
@@ -79,6 +79,7 @@
79
parse_flag_from_env,
80
print_tensor_test,
81
require_torch_gpu,
82
+ skip_mps,
83
slow,
84
torch_all_close,
85
torch_device,
utils/testing_utils.py
@@ -163,6 +163,11 @@ def require_torch_gpu(test_case):
163
)
164
165
166
+def skip_mps(test_case):
167
+ """Decorator marking a test to skip if torch_device is 'mps'"""
168
+ return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
169
+
170
171
def require_flax(test_case):
172
"""
173
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
0 commit comments