Skip to content

Commit d9c7715

Browse files
[Tests] Add MPS skip decorator (huggingface#2362)
* finish * Apply suggestions from code review * fix indent and import error in test_stable_diffusion_depth --------- Co-authored-by: William Berman <[email protected]>
1 parent 8fdbf2f commit d9c7715

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
parse_flag_from_env,
8080
print_tensor_test,
8181
require_torch_gpu,
82+
skip_mps,
8283
slow,
8384
torch_all_close,
8485
torch_device,

utils/testing_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ def require_torch_gpu(test_case):
163163
)
164164

165165

166+
def skip_mps(test_case):
167+
"""Decorator marking a test to skip if torch_device is 'mps'"""
168+
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
169+
170+
166171
def require_flax(test_case):
167172
"""
168173
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed

0 commit comments

Comments
 (0)