Skip to content

Commit 836d02a

Browse files
carmocca authored and Borda committed
Add RunIf
1 parent 4aa9be2 commit 836d02a

File tree

1 file changed

+184
-0
lines changed

1 file changed

+184
-0
lines changed

tests/helpers/runif.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
import sys
16+
from distutils.version import LooseVersion
17+
from typing import Optional
18+
19+
import pytest
20+
import torch
21+
from pkg_resources import get_distribution
22+
23+
from pytorch_lightning.utilities import (
24+
_APEX_AVAILABLE,
25+
_DEEPSPEED_AVAILABLE,
26+
_FAIRSCALE_AVAILABLE,
27+
_FAIRSCALE_PIPE_AVAILABLE,
28+
_HOROVOD_AVAILABLE,
29+
_NATIVE_AMP_AVAILABLE,
30+
_RPC_AVAILABLE,
31+
_TORCH_QUANTIZE_AVAILABLE,
32+
_TPU_AVAILABLE,
33+
)
34+
35+
try:
36+
from horovod.common.util import nccl_built
37+
nccl_built()
38+
except (ImportError, ModuleNotFoundError, AttributeError):
39+
_HOROVOD_NCCL_AVAILABLE = False
40+
finally:
41+
_HOROVOD_NCCL_AVAILABLE = True
42+
43+
44+
class RunIf:
    """
    RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark::

        @RunIf(min_torch="0.0")
        @pytest.mark.parametrize("arg1", [1, 2.0])
        def test_wrapper(arg1):
            assert arg1 > 0.0
    """

    def __new__(
        cls,
        *args,
        min_gpus: int = 0,
        min_torch: Optional[str] = None,
        max_torch: Optional[str] = None,
        min_python: Optional[str] = None,
        quantization: bool = False,
        amp_apex: bool = False,
        amp_native: bool = False,
        tpu: bool = False,
        horovod: bool = False,
        horovod_nccl: bool = False,
        skip_windows: bool = False,
        special: bool = False,
        rpc: bool = False,
        fairscale: bool = False,
        fairscale_pipe: bool = False,
        deepspeed: bool = False,
        **kwargs
    ):
        """
        Build a ``pytest.mark.skipif`` marker from the requested requirements.

        Args:
            args: native pytest.mark.skipif arguments
            min_gpus: min number of gpus required to run test
            min_torch: minimum pytorch version to run test
            max_torch: maximum pytorch version to run test
            min_python: minimum python version required to run test
            quantization: if `torch.quantization` package is required to run test
            amp_apex: NVIDIA Apex is installed
            amp_native: if native PyTorch native AMP is supported
            tpu: if TPU is available
            horovod: if Horovod is installed
            horovod_nccl: if Horovod is installed with NCCL support
            skip_windows: skip test for Windows platform (typically for some limited torch functionality)
            special: running in special mode, outside pytest suite
            rpc: requires Remote Procedure Call (RPC)
            fairscale: if `fairscale` module is required to run the test
            fairscale_pipe: if `fairscale` with pipe module is required to run the test
            deepspeed: if `deepspeed` module is required to run the test
            kwargs: native pytest.mark.skipif keyword arguments
        """
        # Parallel lists: conditions[i] is True when requirement reasons[i]
        # is NOT met; any unmet requirement skips the test.
        conditions = []
        reasons = []

        if min_gpus:
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"GPUs>={min_gpus}")

        if min_torch:
            torch_version = LooseVersion(get_distribution("torch").version)
            conditions.append(torch_version < LooseVersion(min_torch))
            reasons.append(f"torch>={min_torch}")

        if max_torch:
            torch_version = LooseVersion(get_distribution("torch").version)
            conditions.append(torch_version >= LooseVersion(max_torch))
            reasons.append(f"torch<{max_torch}")

        if min_python:
            py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
            # BUG FIX: the original compared a plain str against a
            # LooseVersion, relying on the reflected comparison; wrap both
            # sides so the comparison is always a proper version compare.
            conditions.append(LooseVersion(py_version) < LooseVersion(min_python))
            reasons.append(f"python>={min_python}")

        if quantization:
            # quantized tests need the default 'fbgemm' backend available
            _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
            conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
            reasons.append("PyTorch quantization")

        if amp_native:
            conditions.append(not _NATIVE_AMP_AVAILABLE)
            reasons.append("native AMP")

        if amp_apex:
            conditions.append(not _APEX_AVAILABLE)
            reasons.append("NVIDIA Apex")

        if skip_windows:
            conditions.append(sys.platform == "win32")
            reasons.append("unimplemented on Windows")

        if tpu:
            conditions.append(not _TPU_AVAILABLE)
            reasons.append("TPU")

        if horovod:
            conditions.append(not _HOROVOD_AVAILABLE)
            reasons.append("Horovod")

        if horovod_nccl:
            conditions.append(not _HOROVOD_NCCL_AVAILABLE)
            reasons.append("Horovod with NCCL")

        if special:
            # special tests are opted into explicitly via an env flag
            env_flag = os.getenv("PL_RUNNING_SPECIAL_TESTS", '0')
            conditions.append(env_flag != '1')
            reasons.append("Special execution")

        if rpc:
            conditions.append(not _RPC_AVAILABLE)
            reasons.append("RPC")

        if fairscale:
            conditions.append(not _FAIRSCALE_AVAILABLE)
            reasons.append("Fairscale")

        if fairscale_pipe:
            conditions.append(not _FAIRSCALE_PIPE_AVAILABLE)
            reasons.append("Fairscale Pipe")

        if deepspeed:
            conditions.append(not _DEEPSPEED_AVAILABLE)
            reasons.append("Deepspeed")

        # Keep only the reasons whose requirement actually failed.
        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
        return pytest.mark.skipif(
            *args,
            condition=any(conditions),
            reason=f"Requires: [{' + '.join(reasons)}]",
            **kwargs,
        )
174+
175+
176+
@RunIf(min_torch="99")
def test_always_skip():
    """Sanity check: an impossible torch requirement must always skip this test."""
    # If the skip marker ever failed to apply, fail the run loudly.
    raise SystemExit(1)
179+
180+
181+
@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
@RunIf(min_torch="0.0")
def test_wrapper(arg1: float):
    """Sanity check: a trivially satisfiable requirement lets the test run and parametrize."""
    assert 0.0 < arg1

0 commit comments

Comments
 (0)