Skip to content

Commit 0b82463

Browse files
authored
Implement most common sys.version_info and sys.platform checks (#1942)
Addresses most but not all of #698. (The first two bullets of my plan.)
1 parent f129ee7 commit 0b82463

File tree

5 files changed

+409
-16
lines changed

5 files changed

+409
-16
lines changed

mypy/semanal.py

Lines changed: 152 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
traverse the entire AST.
4444
"""
4545

46+
import sys
4647
from typing import (
4748
List, Dict, Set, Tuple, cast, Any, overload, TypeVar, Union, Optional, Callable
4849
)
@@ -64,7 +65,7 @@
6465
SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr,
6566
YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr,
6667
IntExpr, FloatExpr, UnicodeExpr,
67-
COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED,
68+
COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES,
6869
)
6970
from mypy.visitor import NodeVisitor
7071
from mypy.traverser import TraverserVisitor
@@ -2781,17 +2782,22 @@ def infer_if_condition_value(expr: Node, pyversion: Tuple[int, int]) -> int:
27812782
if alias.op == 'not':
27822783
expr = alias.expr
27832784
negated = True
2785+
result = TRUTH_VALUE_UNKNOWN
27842786
if isinstance(expr, NameExpr):
27852787
name = expr.name
27862788
elif isinstance(expr, MemberExpr):
27872789
name = expr.name
2788-
result = TRUTH_VALUE_UNKNOWN
2789-
if name == 'PY2':
2790-
result = ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE
2791-
elif name == 'PY3':
2792-
result = ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE
2793-
elif name == 'MYPY':
2794-
result = ALWAYS_TRUE
2790+
else:
2791+
result = consider_sys_version_info(expr, pyversion)
2792+
if result == TRUTH_VALUE_UNKNOWN:
2793+
result = consider_sys_platform(expr, sys.platform)
2794+
if result == TRUTH_VALUE_UNKNOWN:
2795+
if name == 'PY2':
2796+
result = ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE
2797+
elif name == 'PY3':
2798+
result = ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE
2799+
elif name == 'MYPY':
2800+
result = ALWAYS_TRUE
27952801
if negated:
27962802
if result == ALWAYS_TRUE:
27972803
result = ALWAYS_FALSE
@@ -2800,6 +2806,144 @@ def infer_if_condition_value(expr: Node, pyversion: Tuple[int, int]) -> int:
28002806
return result
28012807

28022808

2809+
def consider_sys_version_info(expr: Node, pyversion: Tuple[int, ...]) -> int:
2810+
"""Consider whether expr is a comparison involving sys.version_info.
2811+
2812+
Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
2813+
"""
2814+
# Cases supported:
2815+
# - sys.version_info[<int>] <compare_op> <int>
2816+
# - sys.version_info[:<int>] <compare_op> <tuple_of_n_ints>
2817+
# - sys.version_info <compare_op> <tuple_of_1_or_2_ints>
2818+
# (in this case <compare_op> must be >, >=, <, <=, but cannot be ==, !=)
2819+
if not isinstance(expr, ComparisonExpr):
2820+
return TRUTH_VALUE_UNKNOWN
2821+
# Let's not yet support chained comparisons.
2822+
if len(expr.operators) > 1:
2823+
return TRUTH_VALUE_UNKNOWN
2824+
op = expr.operators[0]
2825+
if op not in ('==', '!=', '<=', '>=', '<', '>'):
2826+
return TRUTH_VALUE_UNKNOWN
2827+
thing = contains_int_or_tuple_of_ints(expr.operands[1])
2828+
if thing is None:
2829+
return TRUTH_VALUE_UNKNOWN
2830+
index = contains_sys_version_info(expr.operands[0])
2831+
if isinstance(index, int) and isinstance(thing, int):
2832+
# sys.version_info[i] <compare_op> k
2833+
if 0 <= index <= 1:
2834+
return fixed_comparison(pyversion[index], op, thing)
2835+
else:
2836+
return TRUTH_VALUE_UNKNOWN
2837+
elif isinstance(index, tuple) and isinstance(thing, tuple):
2838+
# Why doesn't mypy see that index can't be None here?
2839+
lo, hi = cast(tuple, index)
2840+
if lo is None:
2841+
lo = 0
2842+
if hi is None:
2843+
hi = 2
2844+
if 0 <= lo < hi <= 2:
2845+
val = pyversion[lo:hi]
2846+
if len(val) == len(thing) or len(val) > len(thing) and op not in ('==', '!='):
2847+
return fixed_comparison(val, op, thing)
2848+
return TRUTH_VALUE_UNKNOWN
2849+
2850+
2851+
def consider_sys_platform(expr: Node, platform: str) -> int:
2852+
"""Consider whether expr is a comparison involving sys.platform.
2853+
2854+
Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
2855+
"""
2856+
# Cases supported:
2857+
# - sys.platform == 'posix'
2858+
# - sys.platform != 'win32'
2859+
# TODO: Maybe support e.g.:
2860+
# - sys.platform.startswith('win')
2861+
if not isinstance(expr, ComparisonExpr):
2862+
return TRUTH_VALUE_UNKNOWN
2863+
# Let's not yet support chained comparisons.
2864+
if len(expr.operators) > 1:
2865+
return TRUTH_VALUE_UNKNOWN
2866+
op = expr.operators[0]
2867+
if op not in ('==', '!='):
2868+
return TRUTH_VALUE_UNKNOWN
2869+
if not is_sys_attr(expr.operands[0], 'platform'):
2870+
return TRUTH_VALUE_UNKNOWN
2871+
right = expr.operands[1]
2872+
if not isinstance(right, (StrExpr, UnicodeExpr)):
2873+
return TRUTH_VALUE_UNKNOWN
2874+
return fixed_comparison(platform, op, right.value)
2875+
2876+
2877+
Targ = TypeVar('Targ', int, str, Tuple[int, ...])
2878+
2879+
2880+
def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
2881+
rmap = {False: ALWAYS_FALSE, True: ALWAYS_TRUE}
2882+
if op == '==':
2883+
return rmap[left == right]
2884+
if op == '!=':
2885+
return rmap[left != right]
2886+
if op == '<=':
2887+
return rmap[left <= right]
2888+
if op == '>=':
2889+
return rmap[left >= right]
2890+
if op == '<':
2891+
return rmap[left < right]
2892+
if op == '>':
2893+
return rmap[left > right]
2894+
return TRUTH_VALUE_UNKNOWN
2895+
2896+
2897+
def contains_int_or_tuple_of_ints(expr: Node) -> Union[None, int, Tuple[int], Tuple[int, ...]]:
2898+
if isinstance(expr, IntExpr):
2899+
return expr.value
2900+
if isinstance(expr, TupleExpr):
2901+
if expr.literal == LITERAL_YES:
2902+
thing = []
2903+
for x in expr.items:
2904+
if not isinstance(x, IntExpr):
2905+
return None
2906+
thing.append(x.value)
2907+
return tuple(thing)
2908+
return None
2909+
2910+
2911+
def contains_sys_version_info(expr: Node) -> Union[None, int, Tuple[Optional[int], Optional[int]]]:
2912+
if is_sys_attr(expr, 'version_info'):
2913+
return (None, None) # Same as sys.version_info[:]
2914+
if isinstance(expr, IndexExpr) and is_sys_attr(expr.base, 'version_info'):
2915+
index = expr.index
2916+
if isinstance(index, IntExpr):
2917+
return index.value
2918+
if isinstance(index, SliceExpr):
2919+
if index.stride is not None:
2920+
if not isinstance(index.stride, IntExpr) or index.stride.value != 1:
2921+
return None
2922+
begin = end = None
2923+
if index.begin_index is not None:
2924+
if not isinstance(index.begin_index, IntExpr):
2925+
return None
2926+
begin = index.begin_index.value
2927+
if index.end_index is not None:
2928+
if not isinstance(index.end_index, IntExpr):
2929+
return None
2930+
end = index.end_index.value
2931+
return (begin, end)
2932+
return None
2933+
2934+
2935+
def is_sys_attr(expr: Node, name: str) -> bool:
2936+
# TODO: This currently doesn't work with code like this:
2937+
# - import sys as _sys
2938+
# - from sys import version_info
2939+
if isinstance(expr, MemberExpr) and expr.name == name:
2940+
if isinstance(expr.expr, NameExpr) and expr.expr.name == 'sys':
2941+
# TODO: Guard against a local named sys, etc.
2942+
# (Though later passes will still do most checking.)
2943+
return True
2944+
return False
2945+
2946+
28032947
def mark_block_unreachable(block: Block) -> None:
28042948
block.is_unreachable = True
28052949
block.accept(MarkImportsUnreachableVisitor())

test-data/samples/crawl.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020

2121

2222
ARGS = argparse.ArgumentParser(description="Web crawler")
23-
ARGS.add_argument(
24-
'--iocp', action='store_true', dest='iocp',
25-
default=False, help='Use IOCP event loop (Windows only)')
23+
if sys.platform == 'win32':
24+
ARGS.add_argument(
25+
'--iocp', action='store_true', dest='iocp',
26+
default=False, help='Use IOCP event loop (Windows only)')
2627
ARGS.add_argument(
2728
'--select', action='store_true', dest='select',
2829
default=False, help='Use Select event loop instead of default')
@@ -824,9 +825,12 @@ def main() -> None:
824825
log = Logger(args.level)
825826

826827
if args.iocp:
827-
from asyncio import ProactorEventLoop
828-
loop = ProactorEventLoop() # type: ignore
829-
asyncio.set_event_loop(loop)
828+
if sys.platform == 'win32':
829+
from asyncio import ProactorEventLoop
830+
loop = ProactorEventLoop() # type: ignore
831+
asyncio.set_event_loop(loop)
832+
else:
833+
assert False
830834
elif args.select:
831835
loop = asyncio.SelectorEventLoop() # type: ignore
832836
asyncio.set_event_loop(loop)

0 commit comments

Comments
 (0)