43
43
traverse the entire AST.
44
44
"""
45
45
46
+ import sys
46
47
from typing import (
47
48
List , Dict , Set , Tuple , cast , Any , overload , TypeVar , Union , Optional , Callable
48
49
)
64
65
SetComprehension , DictionaryComprehension , TYPE_ALIAS , TypeAliasExpr ,
65
66
YieldExpr , ExecStmt , Argument , BackquoteExpr , ImportBase , AwaitExpr ,
66
67
IntExpr , FloatExpr , UnicodeExpr ,
67
- COVARIANT , CONTRAVARIANT , INVARIANT , UNBOUND_IMPORTED ,
68
+ COVARIANT , CONTRAVARIANT , INVARIANT , UNBOUND_IMPORTED , LITERAL_YES ,
68
69
)
69
70
from mypy .visitor import NodeVisitor
70
71
from mypy .traverser import TraverserVisitor
@@ -2781,17 +2782,22 @@ def infer_if_condition_value(expr: Node, pyversion: Tuple[int, int]) -> int:
2781
2782
if alias .op == 'not' :
2782
2783
expr = alias .expr
2783
2784
negated = True
2785
+ result = TRUTH_VALUE_UNKNOWN
2784
2786
if isinstance (expr , NameExpr ):
2785
2787
name = expr .name
2786
2788
elif isinstance (expr , MemberExpr ):
2787
2789
name = expr .name
2788
- result = TRUTH_VALUE_UNKNOWN
2789
- if name == 'PY2' :
2790
- result = ALWAYS_TRUE if pyversion [0 ] == 2 else ALWAYS_FALSE
2791
- elif name == 'PY3' :
2792
- result = ALWAYS_TRUE if pyversion [0 ] == 3 else ALWAYS_FALSE
2793
- elif name == 'MYPY' :
2794
- result = ALWAYS_TRUE
2790
+ else :
2791
+ result = consider_sys_version_info (expr , pyversion )
2792
+ if result == TRUTH_VALUE_UNKNOWN :
2793
+ result = consider_sys_platform (expr , sys .platform )
2794
+ if result == TRUTH_VALUE_UNKNOWN :
2795
+ if name == 'PY2' :
2796
+ result = ALWAYS_TRUE if pyversion [0 ] == 2 else ALWAYS_FALSE
2797
+ elif name == 'PY3' :
2798
+ result = ALWAYS_TRUE if pyversion [0 ] == 3 else ALWAYS_FALSE
2799
+ elif name == 'MYPY' :
2800
+ result = ALWAYS_TRUE
2795
2801
if negated :
2796
2802
if result == ALWAYS_TRUE :
2797
2803
result = ALWAYS_FALSE
@@ -2800,6 +2806,144 @@ def infer_if_condition_value(expr: Node, pyversion: Tuple[int, int]) -> int:
2800
2806
return result
2801
2807
2802
2808
2809
+ def consider_sys_version_info (expr : Node , pyversion : Tuple [int , ...]) -> int :
2810
+ """Consider whether expr is a comparison involving sys.version_info.
2811
+
2812
+ Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
2813
+ """
2814
+ # Cases supported:
2815
+ # - sys.version_info[<int>] <compare_op> <int>
2816
+ # - sys.version_info[:<int>] <compare_op> <tuple_of_n_ints>
2817
+ # - sys.version_info <compare_op> <tuple_of_1_or_2_ints>
2818
+ # (in this case <compare_op> must be >, >=, <, <=, but cannot be ==, !=)
2819
+ if not isinstance (expr , ComparisonExpr ):
2820
+ return TRUTH_VALUE_UNKNOWN
2821
+ # Let's not yet support chained comparisons.
2822
+ if len (expr .operators ) > 1 :
2823
+ return TRUTH_VALUE_UNKNOWN
2824
+ op = expr .operators [0 ]
2825
+ if op not in ('==' , '!=' , '<=' , '>=' , '<' , '>' ):
2826
+ return TRUTH_VALUE_UNKNOWN
2827
+ thing = contains_int_or_tuple_of_ints (expr .operands [1 ])
2828
+ if thing is None :
2829
+ return TRUTH_VALUE_UNKNOWN
2830
+ index = contains_sys_version_info (expr .operands [0 ])
2831
+ if isinstance (index , int ) and isinstance (thing , int ):
2832
+ # sys.version_info[i] <compare_op> k
2833
+ if 0 <= index <= 1 :
2834
+ return fixed_comparison (pyversion [index ], op , thing )
2835
+ else :
2836
+ return TRUTH_VALUE_UNKNOWN
2837
+ elif isinstance (index , tuple ) and isinstance (thing , tuple ):
2838
+ # Why doesn't mypy see that index can't be None here?
2839
+ lo , hi = cast (tuple , index )
2840
+ if lo is None :
2841
+ lo = 0
2842
+ if hi is None :
2843
+ hi = 2
2844
+ if 0 <= lo < hi <= 2 :
2845
+ val = pyversion [lo :hi ]
2846
+ if len (val ) == len (thing ) or len (val ) > len (thing ) and op not in ('==' , '!=' ):
2847
+ return fixed_comparison (val , op , thing )
2848
+ return TRUTH_VALUE_UNKNOWN
2849
+
2850
+
2851
+ def consider_sys_platform (expr : Node , platform : str ) -> int :
2852
+ """Consider whether expr is a comparison involving sys.platform.
2853
+
2854
+ Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
2855
+ """
2856
+ # Cases supported:
2857
+ # - sys.platform == 'posix'
2858
+ # - sys.platform != 'win32'
2859
+ # TODO: Maybe support e.g.:
2860
+ # - sys.platform.startswith('win')
2861
+ if not isinstance (expr , ComparisonExpr ):
2862
+ return TRUTH_VALUE_UNKNOWN
2863
+ # Let's not yet support chained comparisons.
2864
+ if len (expr .operators ) > 1 :
2865
+ return TRUTH_VALUE_UNKNOWN
2866
+ op = expr .operators [0 ]
2867
+ if op not in ('==' , '!=' ):
2868
+ return TRUTH_VALUE_UNKNOWN
2869
+ if not is_sys_attr (expr .operands [0 ], 'platform' ):
2870
+ return TRUTH_VALUE_UNKNOWN
2871
+ right = expr .operands [1 ]
2872
+ if not isinstance (right , (StrExpr , UnicodeExpr )):
2873
+ return TRUTH_VALUE_UNKNOWN
2874
+ return fixed_comparison (platform , op , right .value )
2875
+
2876
+
2877
+ Targ = TypeVar ('Targ' , int , str , Tuple [int , ...])
2878
+
2879
+
2880
+ def fixed_comparison (left : Targ , op : str , right : Targ ) -> int :
2881
+ rmap = {False : ALWAYS_FALSE , True : ALWAYS_TRUE }
2882
+ if op == '==' :
2883
+ return rmap [left == right ]
2884
+ if op == '!=' :
2885
+ return rmap [left != right ]
2886
+ if op == '<=' :
2887
+ return rmap [left <= right ]
2888
+ if op == '>=' :
2889
+ return rmap [left >= right ]
2890
+ if op == '<' :
2891
+ return rmap [left < right ]
2892
+ if op == '>' :
2893
+ return rmap [left > right ]
2894
+ return TRUTH_VALUE_UNKNOWN
2895
+
2896
+
2897
+ def contains_int_or_tuple_of_ints (expr : Node ) -> Union [None , int , Tuple [int ], Tuple [int , ...]]:
2898
+ if isinstance (expr , IntExpr ):
2899
+ return expr .value
2900
+ if isinstance (expr , TupleExpr ):
2901
+ if expr .literal == LITERAL_YES :
2902
+ thing = []
2903
+ for x in expr .items :
2904
+ if not isinstance (x , IntExpr ):
2905
+ return None
2906
+ thing .append (x .value )
2907
+ return tuple (thing )
2908
+ return None
2909
+
2910
+
2911
+ def contains_sys_version_info (expr : Node ) -> Union [None , int , Tuple [Optional [int ], Optional [int ]]]:
2912
+ if is_sys_attr (expr , 'version_info' ):
2913
+ return (None , None ) # Same as sys.version_info[:]
2914
+ if isinstance (expr , IndexExpr ) and is_sys_attr (expr .base , 'version_info' ):
2915
+ index = expr .index
2916
+ if isinstance (index , IntExpr ):
2917
+ return index .value
2918
+ if isinstance (index , SliceExpr ):
2919
+ if index .stride is not None :
2920
+ if not isinstance (index .stride , IntExpr ) or index .stride .value != 1 :
2921
+ return None
2922
+ begin = end = None
2923
+ if index .begin_index is not None :
2924
+ if not isinstance (index .begin_index , IntExpr ):
2925
+ return None
2926
+ begin = index .begin_index .value
2927
+ if index .end_index is not None :
2928
+ if not isinstance (index .end_index , IntExpr ):
2929
+ return None
2930
+ end = index .end_index .value
2931
+ return (begin , end )
2932
+ return None
2933
+
2934
+
2935
+ def is_sys_attr (expr : Node , name : str ) -> bool :
2936
+ # TODO: This currently doesn't work with code like this:
2937
+ # - import sys as _sys
2938
+ # - from sys import version_info
2939
+ if isinstance (expr , MemberExpr ) and expr .name == name :
2940
+ if isinstance (expr .expr , NameExpr ) and expr .expr .name == 'sys' :
2941
+ # TODO: Guard against a local named sys, etc.
2942
+ # (Though later passes will still do most checking.)
2943
+ return True
2944
+ return False
2945
+
2946
+
2803
2947
def mark_block_unreachable (block : Block ) -> None :
2804
2948
block .is_unreachable = True
2805
2949
block .accept (MarkImportsUnreachableVisitor ())
0 commit comments