@@ -3788,67 +3788,109 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
3788
3788
vartype = type_map [expr ]
3789
3789
return self .conditional_callable_type_map (expr , vartype )
3790
3790
elif isinstance (node , ComparisonExpr ):
3791
- operand_types = [coerce_to_literal (type_map [expr ])
3792
- for expr in node .operands if expr in type_map ]
3793
-
3794
- is_not = node .operators == ['is not' ]
3795
- if (is_not or node .operators == ['is' ]) and len (operand_types ) == len (node .operands ):
3796
- if_vars = {} # type: TypeMap
3797
- else_vars = {} # type: TypeMap
3798
-
3799
- for i , expr in enumerate (node .operands ):
3800
- var_type = operand_types [i ]
3801
- other_type = operand_types [1 - i ]
3802
-
3803
- if literal (expr ) == LITERAL_TYPE and is_singleton_type (other_type ):
3804
- # This should only be true at most once: there should be
3805
- # exactly two elements in node.operands and if the 'other type' is
3806
- # a singleton type, it by definition does not need to be narrowed:
3807
- # it already has the most precise type possible so does not need to
3808
- # be narrowed/included in the output map.
3809
- #
3810
- # TODO: Generalize this to handle the case where 'other_type' is
3811
- # a union of singleton types.
3812
-
3813
- if isinstance (other_type , LiteralType ) and other_type .is_enum_literal ():
3814
- fallback_name = other_type .fallback .type .fullname
3815
- var_type = try_expanding_enum_to_union (var_type , fallback_name )
3816
-
3817
- target_type = [TypeRange (other_type , is_upper_bound = False )]
3818
- if_vars , else_vars = conditional_type_map (expr , var_type , target_type )
3819
- break
3791
+ # Step 1: Obtain the types of each operand and whether or not we can
3792
+ # narrow their types. (For example, we shouldn't try narrowing the
3793
+ # types of literal string or enum expressions).
3794
+
3795
+ operands = node .operands
3796
+ operand_types = []
3797
+ narrowable_operand_indices = set ()
3798
+ for i , expr in enumerate (operands ):
3799
+ if expr not in type_map :
3800
+ return {}, {}
3801
+ expr_type = type_map [expr ]
3802
+ operand_types .append (expr_type )
3803
+
3804
+ if (literal (expr ) == LITERAL_TYPE
3805
+ and not is_literal_none (expr )
3806
+ and not is_literal_enum (type_map , expr )):
3807
+ narrowable_operand_indices .add (i )
3808
+
3809
+ # Step 2: Group operands chained by either the 'is' or '==' operators
3810
+ # together. For all other operands, we keep them in groups of size 2.
3811
+ # So the expression:
3812
+ #
3813
+ # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
3814
+ #
3815
+ # ...is converted into the simplified operator list:
3816
+ #
3817
+ # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
3818
+ # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
3819
+ #
3820
+ # We group identity/equality expressions so we can propagate information
3821
+ # we discover about one operand across the entire chain. We don't bother
3822
+ # handling 'is not' and '!=' chains in a special way: those are very rare
3823
+ # in practice.
3824
+
3825
+ simplified_operator_list = [] # type: List[Tuple[str, List[int]]]
3826
+ last_operator = node .operators [0 ]
3827
+ current_group = set () # type: Set[int]
3828
+ for i , (operator , left_expr , right_expr ) in enumerate (node .pairwise ()):
3829
+ if current_group and (operator != last_operator or operator not in {'is' , '==' }):
3830
+ simplified_operator_list .append ((last_operator , sorted (current_group )))
3831
+ last_operator = operator
3832
+ current_group = set ()
3833
+
3834
+ # Note: 'i' corresponds to the left operand index, so 'i + 1' is the
3835
+ # right operand.
3836
+ current_group .add (i )
3837
+ current_group .add (i + 1 )
3838
+
3839
+ simplified_operator_list .append ((last_operator , sorted (current_group )))
3840
+
3841
+ # Step 3: Analyze each group and infer more precise type maps for each
3842
+ # assignable operand, if possible. We combine these type maps together
3843
+ # in the final step.
3844
+
3845
+ partial_type_maps = []
3846
+ for operator , expr_indices in simplified_operator_list :
3847
+ if operator in {'is' , 'is not' }:
3848
+ if_map , else_map = self .refine_identity_comparison_expression (
3849
+ operands ,
3850
+ operand_types ,
3851
+ expr_indices ,
3852
+ narrowable_operand_indices ,
3853
+ )
3854
+ elif operator in {'==' , '!=' }:
3855
+ if_map , else_map = self .refine_equality_comparison_expression (
3856
+ operands ,
3857
+ operand_types ,
3858
+ expr_indices ,
3859
+ narrowable_operand_indices ,
3860
+ )
3861
+ elif operator in {'in' , 'not in' }:
3862
+ assert len (expr_indices ) == 2
3863
+ left_index , right_index = expr_indices
3864
+ if left_index not in narrowable_operand_indices :
3865
+ continue
3820
3866
3821
- if is_not :
3822
- if_vars , else_vars = else_vars , if_vars
3823
- return if_vars , else_vars
3824
- # Check for `x == y` where x is of type Optional[T] and y is of type T
3825
- # or a type that overlaps with T (or vice versa).
3826
- elif node .operators == ['==' ]:
3827
- first_type = type_map [node .operands [0 ]]
3828
- second_type = type_map [node .operands [1 ]]
3829
- if is_optional (first_type ) != is_optional (second_type ):
3830
- if is_optional (first_type ):
3831
- optional_type , comp_type = first_type , second_type
3832
- optional_expr = node .operands [0 ]
3867
+ item_type = operand_types [left_index ]
3868
+ collection_type = operand_types [right_index ]
3869
+
3870
+ # We only try and narrow away 'None' for now
3871
+ if not is_optional (item_type ):
3872
+ pass
3873
+
3874
+ collection_item_type = get_proper_type (builtin_item_type (collection_type ))
3875
+ if collection_item_type is None or is_optional (collection_item_type ):
3876
+ continue
3877
+ if (isinstance (collection_item_type , Instance )
3878
+ and collection_item_type .type .fullname == 'builtins.object' ):
3879
+ continue
3880
+ if is_overlapping_erased_types (item_type , collection_item_type ):
3881
+ if_map , else_map = {operands [left_index ]: remove_optional (item_type )}, {}
3833
3882
else :
3834
- optional_type , comp_type = second_type , first_type
3835
- optional_expr = node .operands [1 ]
3836
- if is_overlapping_erased_types (optional_type , comp_type ):
3837
- return {optional_expr : remove_optional (optional_type )}, {}
3838
- elif node .operators in [['in' ], ['not in' ]]:
3839
- expr = node .operands [0 ]
3840
- left_type = type_map [expr ]
3841
- right_type = get_proper_type (builtin_item_type (type_map [node .operands [1 ]]))
3842
- right_ok = right_type and (not is_optional (right_type ) and
3843
- (not isinstance (right_type , Instance ) or
3844
- right_type .type .fullname != 'builtins.object' ))
3845
- if (right_type and right_ok and is_optional (left_type ) and
3846
- literal (expr ) == LITERAL_TYPE and not is_literal_none (expr ) and
3847
- is_overlapping_erased_types (left_type , right_type )):
3848
- if node .operators == ['in' ]:
3849
- return {expr : remove_optional (left_type )}, {}
3850
- if node .operators == ['not in' ]:
3851
- return {}, {expr : remove_optional (left_type )}
3883
+ continue
3884
+ else :
3885
+ if_map = {}
3886
+ else_map = {}
3887
+
3888
+ if operator in {'is not' , '!=' , 'not in' }:
3889
+ if_map , else_map = else_map , if_map
3890
+
3891
+ partial_type_maps .append ((if_map , else_map ))
3892
+
3893
+ return reduce_partial_type_maps (partial_type_maps )
3852
3894
elif isinstance (node , RefExpr ):
3853
3895
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
3854
3896
# respectively
@@ -4053,6 +4095,120 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
4053
4095
4054
4096
return output
4055
4097
4098
def refine_identity_comparison_expression(self,
                                          operands: List[Expression],
                                          operand_types: List[Type],
                                          chain_indices: List[int],
                                          narrowable_operand_indices: Set[int],
                                          ) -> Tuple[TypeMap, TypeMap]:
    """Build the (if, else) type maps for one chain of identity comparisons.

    'operands' and 'operand_types' always describe *every* operand of the
    enclosing comparison expression; 'chain_indices' selects the operands
    that take part in this particular identity chain. For example, given

        a <= b is c is d <= e

    'operands' and 'operand_types' have length 5 while 'chain_indices'
    is [1, 2, 3].

    'narrowable_operand_indices' is the set of operand indices whose types
    may be refined, i.e. the only operands that can show up as keys of the
    returned maps.

    Returns a pair (if_map, else_map). The if map is None when the chain
    mixes two different singleton types, and both maps are empty when no
    operand of the chain has a singleton type.
    """
    # Pass 1: find the singleton type (if any) this chain compares against,
    # remembering every operand that carries it.
    sentinel_type = None  # type: Optional[ProperType]
    sentinel_candidates = []
    for idx in chain_indices:
        literal_type = coerce_to_literal(operand_types[idx])
        if is_singleton_type(literal_type):
            if sentinel_type and not is_same_type(sentinel_type, literal_type):
                # Two disjoint singleton types can never be identical, so
                # the 'if' branch is unreachable.
                return None, {}
            sentinel_type = literal_type
            sentinel_candidates.append(idx)

    if sentinel_type is None:
        # Without a singleton operand there is nothing we can infer.
        return {}, {}

    # Pass 2: choose which operand acts as the fixed sentinel. That operand is
    # never refined below, so prefer one that could not be refined anyway
    # (the last such candidate wins); otherwise settle for the last candidate.
    non_narrowable = [idx for idx in sentinel_candidates
                      if idx not in narrowable_operand_indices]
    sentinel_index = non_narrowable[-1] if non_narrowable else sentinel_candidates[-1]

    if isinstance(sentinel_type, LiteralType) and sentinel_type.is_enum_literal():
        enum_fullname = sentinel_type.fallback.type.fullname
    else:
        enum_fullname = None

    narrowing_target = [TypeRange(sentinel_type, is_upper_bound=False)]

    # Pass 3: refine every other narrowable operand against the sentinel.
    collected_maps = []
    for idx in chain_indices:
        # Refining the sentinel against itself would make conditional_type_map
        # treat the 'else' branch as unreachable, which is usually not what the
        # user intends when comparing against a fixed sentinel value. Operands
        # that are not narrowable are skipped as well.
        if idx == sentinel_index or idx not in narrowable_operand_indices:
            continue

        refined_from = coerce_to_literal(operand_types[idx])
        if enum_fullname is not None:
            refined_from = try_expanding_enum_to_union(refined_from, enum_fullname)
        collected_maps.append(
            conditional_type_map(operands[idx], refined_from, narrowing_target))

    return reduce_partial_type_maps(collected_maps)
4179
+
4180
def refine_equality_comparison_expression(self,
                                          operands: List[Expression],
                                          operand_types: List[Type],
                                          chain_indices: List[int],
                                          narrowable_operand_indices: Set[int],
                                          ) -> Tuple[TypeMap, TypeMap]:
    """Produce conditional type maps refining expressions used in an equality comparison.

    The parameters mirror those of 'refine_identity_comparison_expression':
    'operands' and 'operand_types' describe every operand of the overall
    comparison expression, 'chain_indices' lists the operands that belong to
    this equality chain (e.g. 'a == b == c'), and 'narrowable_operand_indices'
    is the set of operand indices whose types we are allowed to refine.

    The only refinement made here is removing 'None' from optional operands:
    when the chain mixes optional and non-optional operands, every narrowable
    optional operand *of this chain* whose type overlaps with one of the
    non-optional operand types is mapped to its non-optional type in the
    'if' map.

    Returns a pair (if_map, else_map); the else map is always empty.
    """
    non_optional_types = [operand_types[i] for i in chain_indices
                          if not is_optional(operand_types[i])]

    # Make sure we have a mixture of optional and non-optional types.
    if not non_optional_types or len(non_optional_types) == len(chain_indices):
        return {}, {}

    if_map = {}  # type: Dict[Expression, Type]
    # Only operands that take part in this chain may be refined. Walking all
    # narrowable operands of the enclosing expression would, for something like
    # 'x == 1 != y', strip 'None' from 'y' while analysing the 'x == 1' group
    # even though 'y' is not compared for equality there.
    for i in chain_indices:
        if i not in narrowable_operand_indices:
            continue
        expr_type = operand_types[i]
        if not is_optional(expr_type):
            continue
        if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
            if_map[operands[i]] = remove_optional(expr_type)

    return if_map, {}
4211
+
4056
4212
#
4057
4213
# Helpers
4058
4214
#
@@ -4496,6 +4652,26 @@ def is_false_literal(n: Expression) -> bool:
4496
4652
or isinstance (n , IntExpr ) and n .value == 0 )
4497
4653
4498
4654
4655
def is_literal_enum(type_map: Mapping[Expression, Type], n: Expression) -> bool:
    """Tell whether 'n' is a direct reference to an enum member (e.g. 'Color.RED').

    This holds only when all of the following are true:

    * 'n' is a member expression whose base is a plain name;
    * both 'n' and its base have an entry in 'type_map';
    * the base's type is a function-like type that is a type object;
    * the type of 'n', coerced to a literal, is an enum literal whose
      fallback type is the very type object the base refers to.
    """
    if not (isinstance(n, MemberExpr) and isinstance(n.expr, NameExpr)):
        return False

    raw_parent = type_map.get(n.expr)
    raw_member = type_map.get(n)
    if raw_parent is None or raw_member is None:
        return False

    parent = get_proper_type(raw_parent)
    member = coerce_to_literal(raw_member)
    if not (isinstance(parent, FunctionLike) and isinstance(member, LiteralType)):
        return False

    # The base must denote a class object rather than an ordinary callable.
    if not parent.is_type_obj():
        return False

    return member.is_enum_literal() and member.fallback.type == parent.type_object()
4673
+
4674
+
4499
4675
def is_literal_none (n : Expression ) -> bool :
4500
4676
return isinstance (n , NameExpr ) and n .fullname == 'builtins.None'
4501
4677
@@ -4587,6 +4763,75 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
4587
4763
return result
4588
4764
4589
4765
4766
def or_partial_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
    """Combine what is known from e1 and from e2 into what is known from (e1 or e2).

    Both inputs are treated as *partial* maps that only talk about some
    expressions. In contrast to 'or_conditional_maps', an expression present
    in just one of the maps is therefore kept in the result as-is; only
    expressions that both maps describe get their types unioned.

    If either map is None, the other map is returned unchanged.
    """
    if m1 is None:
        return m2
    if m2 is None:
        return m1

    # Start from everything m2 knows, then fold m1 in. Entries are matched up
    # by literal hash: a match is resolved by unioning the two types (as
    # 'or_conditional_maps' does), anything unmatched is simply carried over
    # (as 'and_conditional_maps' does).
    merged = m2.copy()
    m2_expr_by_hash = {literal_hash(expr): expr for expr in m2}
    for expr, typ in m1.items():
        counterpart = m2_expr_by_hash.get(literal_hash(expr))
        if counterpart is None:
            merged[expr] = typ
            continue
        merged[counterpart] = make_simplified_union([typ, merged[counterpart]])

    return merged
4796
+
4797
+
4798
def reduce_partial_type_maps(type_maps: List[Tuple[TypeMap, TypeMap]]) -> Tuple[TypeMap, TypeMap]:
    """Fold a list of *partial* (if_map, else_map) pairs into a single pair.

    An expression that appears in only one of the maps is always carried over
    to the output; only expressions that occur in several if maps (or several
    else maps) are "and"-ed (respectively "or"-ed) together.

    For instance, the input

        [
            ({x: TypeIfX, shared: TypeIfShared1}, {x: TypeElseX, shared: TypeElseShared1}),
            ({y: TypeIfY, shared: TypeIfShared2}, {y: TypeElseY, shared: TypeElseShared2}),
        ]

    produces

        (
            {x: TypeIfX, y: TypeIfY, shared: PseudoIntersection[TypeIfShared1, TypeIfShared2]},
            {x: TypeElseX, y: TypeElseY, shared: Union[TypeElseShared1, TypeElseShared2]},
        )

    where "PseudoIntersection[X, Y] == Y": mypy does not understand
    intersection types yet, so the right-hand type is arbitrarily chosen.

    An empty list yields two empty maps, and a single pair is returned as-is.
    """
    if len(type_maps) == 0:
        return {}, {}
    if len(type_maps) == 1:
        return type_maps[0]

    combined_if, combined_else = type_maps[0]
    for next_if, next_else in type_maps[1:]:
        # 'and_conditional_maps' treats global and partial maps alike, so no
        # dedicated 'and_partial_conditional_maps' helper is required.
        combined_if = and_conditional_maps(combined_if, next_if)
        combined_else = or_partial_conditional_maps(combined_else, next_else)
    return combined_if, combined_else
4833
+
4834
+
4590
4835
def convert_to_typetype (type_map : TypeMap ) -> TypeMap :
4591
4836
converted_type_map = {} # type: Dict[Expression, Type]
4592
4837
if type_map is None :
0 commit comments