Skip to content

feat: reachability recognizes platform.system() #10686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
65 changes: 43 additions & 22 deletions mypy/reachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def infer_condition_value(expr: Expression, options: Options) -> int:
else:
result = consider_sys_version_info(expr, pyversion)
if result == TRUTH_VALUE_UNKNOWN:
result = consider_sys_platform(expr, options.platform)
result = consider_platform(expr, options.platform)
if result == TRUTH_VALUE_UNKNOWN:
if name == 'PY2':
result = ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE
Expand Down Expand Up @@ -162,38 +162,59 @@ def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> i
return TRUTH_VALUE_UNKNOWN


def consider_sys_platform(expr: Expression, platform: str) -> int:
"""Consider whether expr is a comparison involving sys.platform.
def consider_platform(expr: Expression, sys_platform: str) -> int:
"""Consider whether expr is a comparison involving sys.platform or platform.system().

Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.

Cases supported:
- sys.platform == 'posix'
- sys.platform != 'win32'
- sys.platform.startswith('win')
- platform.system() == 'Linux'
- platform.system() != 'Windows'
- ...
"""
# Cases supported:
# - sys.platform == 'posix'
# - sys.platform != 'win32'
# - sys.platform.startswith('win')
MAPPING = {
'darwin': 'Darwin',
'linux': 'Linux',
'win32': 'Windows',
# 'aix': N/A,
# 'cygwin': N/A,
# 'freebsd': N/A,
# ... : N/A
}

if isinstance(expr, ComparisonExpr):
# Let's not yet support chained comparisons.
if len(expr.operators) > 1:
return TRUTH_VALUE_UNKNOWN

op = expr.operators[0]
if op not in ('==', '!='):
return TRUTH_VALUE_UNKNOWN
if not is_sys_attr(expr.operands[0], 'platform'):
return TRUTH_VALUE_UNKNOWN

right = expr.operands[1]
if not isinstance(right, (StrExpr, UnicodeExpr)):
return TRUTH_VALUE_UNKNOWN
return fixed_comparison(platform, op, right.value)

if expr_matches(expr.operands[0], 'sys', 'platform'):
return fixed_comparison(sys_platform, op, right.value)
if isinstance(expr.operands[0], CallExpr) \
and expr_matches(expr.operands[0].callee, 'platform', 'system') \
and sys_platform in MAPPING:
return fixed_comparison(MAPPING[sys_platform], op, right.value)
return TRUTH_VALUE_UNKNOWN
elif isinstance(expr, CallExpr):
if not isinstance(expr.callee, MemberExpr):
return TRUTH_VALUE_UNKNOWN
if len(expr.args) != 1 or not isinstance(expr.args[0], (StrExpr, UnicodeExpr)):
return TRUTH_VALUE_UNKNOWN
if not is_sys_attr(expr.callee.expr, 'platform'):
if not expr_matches(expr.callee.expr, 'sys', 'platform'):
return TRUTH_VALUE_UNKNOWN
if expr.callee.name != 'startswith':
return TRUTH_VALUE_UNKNOWN
if platform.startswith(expr.args[0].value):
if sys_platform.startswith(expr.args[0].value):
return ALWAYS_TRUE
else:
return ALWAYS_FALSE
Expand Down Expand Up @@ -238,9 +259,9 @@ def contains_int_or_tuple_of_ints(expr: Expression

def contains_sys_version_info(expr: Expression
) -> Union[None, int, Tuple[Optional[int], Optional[int]]]:
if is_sys_attr(expr, 'version_info'):
if expr_matches(expr, 'sys', 'version_info'):
return (None, None) # Same as sys.version_info[:]
if isinstance(expr, IndexExpr) and is_sys_attr(expr.base, 'version_info'):
if isinstance(expr, IndexExpr) and expr_matches(expr.base, 'sys', 'version_info'):
index = expr.index
if isinstance(index, IntExpr):
return index.value
Expand All @@ -261,15 +282,15 @@ def contains_sys_version_info(expr: Expression
return None


def is_sys_attr(expr: Expression, name: str) -> bool:
def expr_matches(expr: Expression, package: str, attr: str) -> bool:
# TODO: This currently doesn't work with code like this:
# - import sys as _sys
# - from sys import version_info
if isinstance(expr, MemberExpr) and expr.name == name:
if isinstance(expr.expr, NameExpr) and expr.expr.name == 'sys':
# TODO: Guard against a local named sys, etc.
# (Though later passes will still do most checking.)
return True
# - import package as _package
# - from package import version_info
if isinstance(expr, MemberExpr) and expr.name == attr \
and isinstance(expr.expr, NameExpr) and expr.expr.name == package:
# TODO: Guard against a local variable hiding package, etc.
# (Though later passes will still do most checking)
return True
return False


Expand Down
Loading