Commit 91d5bf9

Make reachability code understand chained comparisons (v2)
This pull request is v2 (well, more like v10...) of my attempts to make our reachability code better understand chained comparisons.

Unlike #7169, this diff focuses exclusively on adding support for chained operation comparisons and deliberately does not attempt to change any of the semantics of how identity and equality operations are performed.

Specifically, mypy currently only examines the first two operands within a comparison expression when refining types. That means neither of the following expressions behaves as expected:

```python
x: MyEnum
y: MyEnum

if x is y is MyEnum.A:
    # x and y are not narrowed at all
    pass

if x is MyEnum.A is y:
    # Only x is narrowed to Literal[MyEnum.A]
    pass
```

This pull request fixes this so we correctly infer the literal type for x and y in both conditionals.

Some additional notes:

1. While analyzing our codebase, I found that while comparison expressions involving two or more `is` or `==` operators were somewhat common, there were almost no comparisons involving chains of `!=` or `is not` operators, and no comparisons involving "disjoint chains" -- e.g. expressions like `a == b < c == b` where there are multiple "disjoint" chains of equality comparisons.

   So, this diff is primarily designed to handle the case where a comparison expression has just one chain of `is` or `==`. For all other cases, I fall back to the more naive strategy of evaluating each comparison individually and and-ing the inferred types together without attempting to propagate any info.

2. I tested this code against one of our internal codebases. This ended up making mypy produce 3 or 4 new errors, but they all seemed legitimate, as far as I can tell.

3. I plan on submitting a follow-up diff that takes advantage of the work done in this diff to complete support for tagged unions using any Literal key, as previously promised.

   (I tried adding support for tagged unions in this diff, but attempting to simultaneously add support for chained comparisons while overhauling the semantics of `==` proved to be a little too overwhelming for me. So, baby steps.)
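
The grouping strategy from note 1 can be illustrated outside of mypy. The following is a minimal standalone sketch, not the code from this commit: `group_comparison_chains` and `CHAINABLE` are hypothetical names, and the function works on a plain list of operator strings rather than on a `ComparisonExpr`. Operator `i` is assumed to sit between operand `i` and operand `i + 1`.

```python
from typing import List, Set, Tuple

# Assumption: only runs of 'is' or '==' are merged into one group;
# every other operator stays in a group of exactly two operands.
CHAINABLE = {"is", "=="}


def group_comparison_chains(operators: List[str]) -> List[Tuple[str, List[int]]]:
    """Group operand indices by runs of the same chainable operator."""
    if not operators:
        return []

    groups: List[Tuple[str, List[int]]] = []
    last_operator = operators[0]
    current: Set[int] = set()

    for i, operator in enumerate(operators):
        # Close the current group when the operator changes or cannot be chained.
        if current and (operator != last_operator or operator not in CHAINABLE):
            groups.append((last_operator, sorted(current)))
            current = set()
        last_operator = operator
        # Operator i joins operand i (left) and operand i + 1 (right).
        current.update((i, i + 1))

    groups.append((last_operator, sorted(current)))
    return groups


# a == b < c == b: the two '==' chains are disjoint, so they stay separate.
print(group_comparison_chains(["==", "<", "=="]))
```

With this sketch, a single `is` chain such as `x is y is MyEnum.A` produces one group covering all three operands, which is the case the diff is primarily designed for.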
1 parent a918ce8 commit 91d5bf9

3 files changed: +427 -59 lines changed


mypy/checker.py

Lines changed: 304 additions & 59 deletions
@@ -3788,67 +3788,109 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
                     vartype = type_map[expr]
                     return self.conditional_callable_type_map(expr, vartype)
         elif isinstance(node, ComparisonExpr):
-            operand_types = [coerce_to_literal(type_map[expr])
-                             for expr in node.operands if expr in type_map]
-
-            is_not = node.operators == ['is not']
-            if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):
-                if_vars = {}  # type: TypeMap
-                else_vars = {}  # type: TypeMap
-
-                for i, expr in enumerate(node.operands):
-                    var_type = operand_types[i]
-                    other_type = operand_types[1 - i]
-
-                    if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type):
-                        # This should only be true at most once: there should be
-                        # exactly two elements in node.operands and if the 'other type' is
-                        # a singleton type, it by definition does not need to be narrowed:
-                        # it already has the most precise type possible so does not need to
-                        # be narrowed/included in the output map.
-                        #
-                        # TODO: Generalize this to handle the case where 'other_type' is
-                        # a union of singleton types.
-
-                        if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
-                            fallback_name = other_type.fallback.type.fullname
-                            var_type = try_expanding_enum_to_union(var_type, fallback_name)
-
-                        target_type = [TypeRange(other_type, is_upper_bound=False)]
-                        if_vars, else_vars = conditional_type_map(expr, var_type, target_type)
-                        break
+            # Step 1: Obtain the types of each operand and whether or not we can
+            # narrow their types. (For example, we shouldn't try narrowing the
+            # types of literal string or enum expressions).
+
+            operands = node.operands
+            operand_types = []
+            narrowable_operand_indices = set()
+            for i, expr in enumerate(operands):
+                if expr not in type_map:
+                    return {}, {}
+                expr_type = type_map[expr]
+                operand_types.append(expr_type)
+
+                if (literal(expr) == LITERAL_TYPE
+                        and not is_literal_none(expr)
+                        and not is_literal_enum(type_map, expr)):
+                    narrowable_operand_indices.add(i)
+
+            # Step 2: Group operands chained by either the 'is' or '==' operators
+            # together. For all other operators, we keep them in groups of size 2.
+            # So the expression:
+            #
+            #   x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
+            #
+            # ...is converted into the simplified operator list:
+            #
+            #  [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
+            #   ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
+            #
+            # We group identity/equality expressions so we can propagate information
+            # we discover about one operand across the entire chain. We don't bother
+            # handling 'is not' and '!=' chains in a special way: those are very rare
+            # in practice.
+
+            simplified_operator_list = []  # type: List[Tuple[str, List[int]]]
+            last_operator = node.operators[0]
+            current_group = set()  # type: Set[int]
+            for i, (operator, left_expr, right_expr) in enumerate(node.pairwise()):
+                if current_group and (operator != last_operator or operator not in {'is', '=='}):
+                    simplified_operator_list.append((last_operator, sorted(current_group)))
+                    last_operator = operator
+                    current_group = set()
+
+                # Note: 'i' corresponds to the left operand index, so 'i + 1' is the
+                # right operand.
+                current_group.add(i)
+                current_group.add(i + 1)
+
+            simplified_operator_list.append((last_operator, sorted(current_group)))
+
+            # Step 3: Analyze each group and infer more precise type maps for each
+            # assignable operand, if possible. We combine these type maps together
+            # in the final step.
+
+            partial_type_maps = []
+            for operator, expr_indices in simplified_operator_list:
+                if operator in {'is', 'is not'}:
+                    if_map, else_map = self.refine_identity_comparison_expression(
+                        operands,
+                        operand_types,
+                        expr_indices,
+                        narrowable_operand_indices,
+                    )
+                elif operator in {'==', '!='}:
+                    if_map, else_map = self.refine_equality_comparison_expression(
+                        operands,
+                        operand_types,
+                        expr_indices,
+                        narrowable_operand_indices,
+                    )
+                elif operator in {'in', 'not in'}:
+                    assert len(expr_indices) == 2
+                    left_index, right_index = expr_indices
+                    if left_index not in narrowable_operand_indices:
+                        continue

-                if is_not:
-                    if_vars, else_vars = else_vars, if_vars
-                return if_vars, else_vars
-            # Check for `x == y` where x is of type Optional[T] and y is of type T
-            # or a type that overlaps with T (or vice versa).
-            elif node.operators == ['==']:
-                first_type = type_map[node.operands[0]]
-                second_type = type_map[node.operands[1]]
-                if is_optional(first_type) != is_optional(second_type):
-                    if is_optional(first_type):
-                        optional_type, comp_type = first_type, second_type
-                        optional_expr = node.operands[0]
+                    item_type = operand_types[left_index]
+                    collection_type = operand_types[right_index]
+
+                    # We only try and narrow away 'None' for now
+                    if not is_optional(item_type):
+                        pass
+
+                    collection_item_type = get_proper_type(builtin_item_type(collection_type))
+                    if collection_item_type is None or is_optional(collection_item_type):
+                        continue
+                    if (isinstance(collection_item_type, Instance)
+                            and collection_item_type.type.fullname == 'builtins.object'):
+                        continue
+                    if is_overlapping_erased_types(item_type, collection_item_type):
+                        if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
                     else:
-                        optional_type, comp_type = second_type, first_type
-                        optional_expr = node.operands[1]
-                    if is_overlapping_erased_types(optional_type, comp_type):
-                        return {optional_expr: remove_optional(optional_type)}, {}
-            elif node.operators in [['in'], ['not in']]:
-                expr = node.operands[0]
-                left_type = type_map[expr]
-                right_type = get_proper_type(builtin_item_type(type_map[node.operands[1]]))
-                right_ok = right_type and (not is_optional(right_type) and
-                                           (not isinstance(right_type, Instance) or
-                                            right_type.type.fullname != 'builtins.object'))
-                if (right_type and right_ok and is_optional(left_type) and
-                        literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and
-                        is_overlapping_erased_types(left_type, right_type)):
-                    if node.operators == ['in']:
-                        return {expr: remove_optional(left_type)}, {}
-                    if node.operators == ['not in']:
-                        return {}, {expr: remove_optional(left_type)}
+                        continue
+                else:
+                    if_map = {}
+                    else_map = {}
+
+                if operator in {'is not', '!=', 'not in'}:
+                    if_map, else_map = else_map, if_map
+
+                partial_type_maps.append((if_map, else_map))
+
+            return reduce_partial_type_maps(partial_type_maps)
         elif isinstance(node, RefExpr):
             # Restrict the type of the variable to True-ish/False-ish in the if and else branches
             # respectively
@@ -4053,6 +4095,120 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:

         return output

+    def refine_identity_comparison_expression(self,
+                                              operands: List[Expression],
+                                              operand_types: List[Type],
+                                              chain_indices: List[int],
+                                              narrowable_operand_indices: Set[int],
+                                              ) -> Tuple[TypeMap, TypeMap]:
+        """Produces conditional type maps refining expressions used in an identity comparison.
+
+        The 'operands' and 'operand_types' lists should be the full list of operands used
+        in the overall comparison expression. The 'chain_indices' list is the list of indices
+        actually used within this identity comparison chain.
+
+        So if we have the expression:
+
+            a <= b is c is d <= e
+
+        ...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices'
+        would be the list [1, 2, 3].
+
+        The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
+        to refine the types of: that is, all operands that will potentially be a part of
+        the output TypeMaps.
+        """
+        singleton = None  # type: Optional[ProperType]
+        possible_singleton_indices = []
+        for i in chain_indices:
+            coerced_type = coerce_to_literal(operand_types[i])
+            if not is_singleton_type(coerced_type):
+                continue
+            if singleton and not is_same_type(singleton, coerced_type):
+                # We have multiple disjoint singleton types. So the 'if' branch
+                # must be unreachable.
+                return None, {}
+            singleton = coerced_type
+            possible_singleton_indices.append(i)
+
+        # There's nothing we can currently infer if none of the operands are singleton types,
+        # so we end early and infer nothing.
+        if singleton is None:
+            return {}, {}
+
+        # If possible, use an unassignable expression as the singleton.
+        # We skip refining the type of the singleton below, so ideally we'd
+        # want to pick an expression we were going to skip anyways.
+        singleton_index = -1
+        for i in possible_singleton_indices:
+            if i not in narrowable_operand_indices:
+                singleton_index = i
+
+        # Oh well, give up and just arbitrarily pick the last item.
+        if singleton_index == -1:
+            singleton_index = possible_singleton_indices[-1]
+
+        enum_name = None
+        if isinstance(singleton, LiteralType) and singleton.is_enum_literal():
+            enum_name = singleton.fallback.type.fullname
+
+        target_type = [TypeRange(singleton, is_upper_bound=False)]
+
+        partial_type_maps = []
+        for i in chain_indices:
+            # If we try refining a singleton against itself, conditional_type_map
+            # will end up assuming that the 'else' branch is unreachable. This is
+            # typically not what we want: generally the user will intend for the
+            # singleton type to be some fixed 'sentinel' value and will want to refine
+            # the other exprs against this one instead.
+            if i == singleton_index:
+                continue
+
+            # Naturally, we can't refine operands which are not permitted to be refined.
+            if i not in narrowable_operand_indices:
+                continue
+
+            expr = operands[i]
+            expr_type = coerce_to_literal(operand_types[i])
+
+            if enum_name is not None:
+                expr_type = try_expanding_enum_to_union(expr_type, enum_name)
+            partial_type_maps.append(conditional_type_map(expr, expr_type, target_type))
+
+        return reduce_partial_type_maps(partial_type_maps)
+
+    def refine_equality_comparison_expression(self,
+                                              operands: List[Expression],
+                                              operand_types: List[Type],
+                                              chain_indices: List[int],
+                                              narrowable_operand_indices: Set[int],
+                                              ) -> Tuple[TypeMap, TypeMap]:
+        """Produces conditional type maps refining expressions used in an equality comparison.
+
+        For more details, see the docstring of 'refine_identity_comparison_expression' up above.
+        The only difference is that this function is for refining equality operations
+        (e.g. 'a == b == c') instead of identity ('a is b is c').
+        """
+        non_optional_types = []
+        for i in chain_indices:
+            typ = operand_types[i]
+            if not is_optional(typ):
+                non_optional_types.append(typ)
+
+        # Make sure we have a mixture of optional and non-optional types.
+        if len(non_optional_types) == 0 or len(non_optional_types) == len(chain_indices):
+            return {}, {}
+
+        if_map = {}
+        for i in narrowable_operand_indices:
+            expr_type = operand_types[i]
+            if not is_optional(expr_type):
+                continue
+            if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
+                if_map[operands[i]] = remove_optional(expr_type)
+
+        return if_map, {}
+
     #
     # Helpers
     #
@@ -4496,6 +4652,26 @@ def is_false_literal(n: Expression) -> bool:
             or isinstance(n, IntExpr) and n.value == 0)


+def is_literal_enum(type_map: Mapping[Expression, Type], n: Expression) -> bool:
+    if not isinstance(n, MemberExpr) or not isinstance(n.expr, NameExpr):
+        return False
+
+    parent_type = type_map.get(n.expr)
+    member_type = type_map.get(n)
+    if member_type is None or parent_type is None:
+        return False
+
+    parent_type = get_proper_type(parent_type)
+    member_type = coerce_to_literal(member_type)
+    if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType):
+        return False
+
+    if not parent_type.is_type_obj():
+        return False
+
+    return member_type.is_enum_literal() and member_type.fallback.type == parent_type.type_object()
+
+
 def is_literal_none(n: Expression) -> bool:
     return isinstance(n, NameExpr) and n.fullname == 'builtins.None'

@@ -4587,6 +4763,75 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
     return result


+def or_partial_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
+    """Calculate what information we can learn from the truth of (e1 or e2)
+    in terms of the information that we can learn from the truth of e1 and
+    the truth of e2.
+
+    Unlike 'or_conditional_maps', we include an expression in the output even
+    if it exists in only one map: we're assuming both maps are "partial" and
+    contain information about only some expressions, and so we "or" together
+    expressions both maps have information on.
+    """
+
+    if m1 is None:
+        return m2
+    if m2 is None:
+        return m1
+    # The logic here is a blend between 'and_conditional_maps'
+    # and 'or_conditional_maps'. We use the high-level logic from the
+    # former to ensure all expressions make it in the output map,
+    # but resolve cases where both maps contain info on the same
+    # expr using the unioning strategy from the latter.
+    result = m2.copy()
+    m2_keys = {literal_hash(n2): n2 for n2 in m2}
+    for n1 in m1:
+        n2 = m2_keys.get(literal_hash(n1))
+        if n2 is None:
+            result[n1] = m1[n1]
+        else:
+            result[n2] = make_simplified_union([m1[n1], result[n2]])
+
+    return result
+
+
+def reduce_partial_type_maps(type_maps: List[Tuple[TypeMap, TypeMap]]) -> Tuple[TypeMap, TypeMap]:
+    """Reduces a list containing pairs of *partial* if/else TypeMaps into a single pair.
+
+    That is, if an expression exists in only one map, we always include it in the output.
+    We only "and"/"or" together expressions that appear in multiple if/else maps.
+
+    So for example, if we had the input:
+
+        [
+            ({x: TypeIfX, shared: TypeIfShared1}, {x: TypeElseX, shared: TypeElseShared1}),
+            ({y: TypeIfY, shared: TypeIfShared2}, {y: TypeElseY, shared: TypeElseShared2}),
+        ]
+
+    ...we'd return the output:
+
+        (
+            {x: TypeIfX, y: TypeIfY, shared: PseudoIntersection[TypeIfShared1, TypeIfShared2]},
+            {x: TypeElseX, y: TypeElseY, shared: Union[TypeElseShared1, TypeElseShared2]},
+        )
+
+    ...where "PseudoIntersection[X, Y] == Y" because mypy actually doesn't understand intersections
+    yet, so we settle for just arbitrarily picking the right expr's type.
+    """
+    if len(type_maps) == 0:
+        return {}, {}
+    elif len(type_maps) == 1:
+        return type_maps[0]
+    else:
+        final_if_map, final_else_map = type_maps[0]
+        for if_map, else_map in type_maps[1:]:
+            # 'and_conditional_maps' does the same thing for both global and partial type maps,
+            # which is why we don't need to have an 'and_partial_conditional_maps' function.
+            final_if_map = and_conditional_maps(final_if_map, if_map)
+            final_else_map = or_partial_conditional_maps(final_else_map, else_map)
+        return final_if_map, final_else_map
+
+
 def convert_to_typetype(type_map: TypeMap) -> TypeMap:
     converted_type_map = {}  # type: Dict[Expression, Type]
     if type_map is None:
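
The merging behaviour documented in the `reduce_partial_type_maps` docstring can be tried out with a toy model. This is a sketch under stated assumptions, not mypy code: a "type" is modelled as a `frozenset` of possible values, a map of `None` means the branch is unreachable, and `and_maps`, `or_partial_maps`, and `reduce_partial_maps` are hypothetical stand-ins for the real helpers. In particular, `and_maps` assumes the right-hand map wins on a clash, as the docstring's `PseudoIntersection[X, Y] == Y` remark describes.

```python
from typing import Dict, FrozenSet, List, Optional, Tuple

# Hypothetical stand-ins for mypy's Type and TypeMap.
ToyType = FrozenSet[str]
ToyMap = Optional[Dict[str, ToyType]]


def and_maps(m1: ToyMap, m2: ToyMap) -> ToyMap:
    """Both conditions hold; unreachable if either side is unreachable."""
    if m1 is None or m2 is None:
        return None
    result = dict(m1)
    result.update(m2)  # on a shared key, the right-hand map's type wins
    return result


def or_partial_maps(m1: ToyMap, m2: ToyMap) -> ToyMap:
    """Either condition holds; keep every key and union the shared ones."""
    if m1 is None:
        return m2
    if m2 is None:
        return m1
    result = dict(m2)
    for name, typ in m1.items():
        result[name] = (typ | result[name]) if name in result else typ
    return result


def reduce_partial_maps(pairs: List[Tuple[ToyMap, ToyMap]]) -> Tuple[ToyMap, ToyMap]:
    """Fold a list of partial (if_map, else_map) pairs into one pair."""
    if not pairs:
        return {}, {}
    final_if, final_else = pairs[0]
    for if_map, else_map in pairs[1:]:
        final_if = and_maps(final_if, if_map)
        final_else = or_partial_maps(final_else, else_map)
    return final_if, final_else


pairs = [
    ({"x": frozenset({"A"}), "shared": frozenset({"A"})},
     {"x": frozenset({"B", "C"}), "shared": frozenset({"B"})}),
    ({"y": frozenset({"A"}), "shared": frozenset({"A", "B"})},
     {"y": frozenset({"B"}), "shared": frozenset({"C"})}),
]
if_map, else_map = reduce_partial_maps(pairs)
```

In this toy run, `x` and `y` each appear in only one pair and are carried through unchanged, while `shared` takes the second pair's type in the if-map and the union of both types in the else-map.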
