Skip to content

Commit 226661f

Browse files
JukkaL and freundTech authored
Exhaustiveness checking for match statements (#12267)
Closes #12010.

Mypy can now detect if a match statement covers all the possible values. Example:

```
def f(x: int | str) -> int:
    match x:
        case str():
            return 0
        case int():
            return 1
    # Mypy knows that we can't reach here
```

Most of the work was done by @freundTech. I did various minor updates and changes to tests.

This doesn't handle some cases properly, including these:

1. We don't recognize that `match [*args]` fully covers a list type
2. Fake intersections don't work quite right (some tests are skipped)
3. We assume enums don't have custom `__eq__` methods

Co-authored-by: Adrian Freund <[email protected]>
1 parent fce1b54 commit 226661f

File tree

4 files changed

+520
-142
lines changed

4 files changed

+520
-142
lines changed

mypy/checker.py

+62-23
Original file line numberDiff line numberDiff line change
@@ -4089,36 +4089,57 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
40894089
if isinstance(subject_type, DeletedType):
40904090
self.msg.deleted_as_rvalue(subject_type, s)
40914091

4092+
# We infer types of patterns twice. The first pass is used
4093+
# to infer the types of capture variables. The type of a
4094+
# capture variable may depend on multiple patterns (it
4095+
# will be a union of all capture types). This pass ignores
4096+
# guard expressions.
40924097
pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns]
4093-
40944098
type_maps: List[TypeMap] = [t.captures for t in pattern_types]
4095-
self.infer_variable_types_from_type_maps(type_maps)
4099+
inferred_types = self.infer_variable_types_from_type_maps(type_maps)
40964100

4097-
for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies):
4101+
# The second pass narrows down the types and type checks bodies.
4102+
for p, g, b in zip(s.patterns, s.guards, s.bodies):
4103+
current_subject_type = self.expr_checker.narrow_type_from_binder(s.subject,
4104+
subject_type)
4105+
pattern_type = self.pattern_checker.accept(p, current_subject_type)
40984106
with self.binder.frame_context(can_skip=True, fall_through=2):
40994107
if b.is_unreachable or isinstance(get_proper_type(pattern_type.type),
41004108
UninhabitedType):
41014109
self.push_type_map(None)
4110+
else_map: TypeMap = {}
41024111
else:
4103-
self.binder.put(s.subject, pattern_type.type)
4112+
pattern_map, else_map = conditional_types_to_typemaps(
4113+
s.subject,
4114+
pattern_type.type,
4115+
pattern_type.rest_type
4116+
)
4117+
self.remove_capture_conflicts(pattern_type.captures,
4118+
inferred_types)
4119+
self.push_type_map(pattern_map)
41044120
self.push_type_map(pattern_type.captures)
41054121
if g is not None:
4106-
gt = get_proper_type(self.expr_checker.accept(g))
4122+
with self.binder.frame_context(can_skip=True, fall_through=3):
4123+
gt = get_proper_type(self.expr_checker.accept(g))
41074124

4108-
if isinstance(gt, DeletedType):
4109-
self.msg.deleted_as_rvalue(gt, s)
4125+
if isinstance(gt, DeletedType):
4126+
self.msg.deleted_as_rvalue(gt, s)
41104127

4111-
if_map, _ = self.find_isinstance_check(g)
4128+
guard_map, guard_else_map = self.find_isinstance_check(g)
4129+
else_map = or_conditional_maps(else_map, guard_else_map)
41124130

4113-
self.push_type_map(if_map)
4114-
self.accept(b)
4131+
self.push_type_map(guard_map)
4132+
self.accept(b)
4133+
else:
4134+
self.accept(b)
4135+
self.push_type_map(else_map)
41154136

41164137
# This is needed due to a quirk in frame_context. Without it types will stay narrowed
41174138
# after the match.
41184139
with self.binder.frame_context(can_skip=False, fall_through=2):
41194140
pass
41204141

4121-
def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None:
4142+
def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[Var, Type]:
41224143
all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list)
41234144
for tm in type_maps:
41244145
if tm is not None:
@@ -4128,28 +4149,38 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None:
41284149
assert isinstance(node, Var)
41294150
all_captures[node].append((expr, typ))
41304151

4152+
inferred_types: Dict[Var, Type] = {}
41314153
for var, captures in all_captures.items():
4132-
conflict = False
4154+
already_exists = False
41334155
types: List[Type] = []
41344156
for expr, typ in captures:
41354157
types.append(typ)
41364158

4137-
previous_type, _, inferred = self.check_lvalue(expr)
4159+
previous_type, _, _ = self.check_lvalue(expr)
41384160
if previous_type is not None:
4139-
conflict = True
4140-
self.check_subtype(typ, previous_type, expr,
4141-
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
4142-
subtype_label="pattern captures type",
4143-
supertype_label="variable has type")
4144-
for type_map in type_maps:
4145-
if type_map is not None and expr in type_map:
4146-
del type_map[expr]
4147-
4148-
if not conflict:
4161+
already_exists = True
4162+
if self.check_subtype(typ, previous_type, expr,
4163+
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
4164+
subtype_label="pattern captures type",
4165+
supertype_label="variable has type"):
4166+
inferred_types[var] = previous_type
4167+
4168+
if not already_exists:
41494169
new_type = UnionType.make_union(types)
41504170
# Infer the union type at the first occurrence
41514171
first_occurrence, _ = captures[0]
4172+
inferred_types[var] = new_type
41524173
self.infer_variable_type(var, first_occurrence, new_type, first_occurrence)
4174+
return inferred_types
4175+
4176+
def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: Dict[Var, Type]) -> None:
4177+
if type_map:
4178+
for expr, typ in list(type_map.items()):
4179+
if isinstance(expr, NameExpr):
4180+
node = expr.node
4181+
assert isinstance(node, Var)
4182+
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
4183+
del type_map[expr]
41534184

41544185
def make_fake_typeinfo(self,
41554186
curr_module_fullname: str,
@@ -5637,6 +5668,14 @@ def conditional_types(current_type: Type,
56375668
None means no new information can be inferred. If default is set it is returned
56385669
instead."""
56395670
if proposed_type_ranges:
5671+
if len(proposed_type_ranges) == 1:
5672+
target = proposed_type_ranges[0].item
5673+
target = get_proper_type(target)
5674+
if isinstance(target, LiteralType) and (target.is_enum_literal()
5675+
or isinstance(target.value, bool)):
5676+
enum_name = target.fallback.type.fullname
5677+
current_type = try_expanding_sum_type_to_union(current_type,
5678+
enum_name)
56405679
proposed_items = [type_range.item for type_range in proposed_type_ranges]
56415680
proposed_type = make_simplified_union(proposed_items)
56425681
if isinstance(proposed_type, AnyType):

mypy/checkpattern.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Pattern checker. This file is conceptually part of TypeChecker."""
2+
23
from collections import defaultdict
34
from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union
45
from typing_extensions import Final
@@ -19,7 +20,8 @@
1920
)
2021
from mypy.plugin import Plugin
2122
from mypy.subtypes import is_subtype
22-
from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union
23+
from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union, \
24+
coerce_to_literal
2325
from mypy.types import (
2426
ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type,
2527
TypedDictType, TupleType, NoneType, UnionType
@@ -55,7 +57,7 @@
5557
'PatternType',
5658
[
5759
('type', Type), # The type the match subject can be narrowed to
58-
('rest_type', Type), # For exhaustiveness checking. Not used yet
60+
('rest_type', Type), # The remaining type if the pattern didn't match
5961
('captures', Dict[Expression, Type]), # The variables captured by the pattern
6062
])
6163

@@ -177,6 +179,7 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
177179
def visit_value_pattern(self, o: ValuePattern) -> PatternType:
178180
current_type = self.type_context[-1]
179181
typ = self.chk.expr_checker.accept(o.expr)
182+
typ = coerce_to_literal(typ)
180183
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
181184
current_type,
182185
[get_type_range(typ)],
@@ -259,6 +262,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
259262
new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types,
260263
star_position,
261264
len(inner_types))
265+
rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types,
266+
star_position,
267+
len(inner_types))
262268

263269
#
264270
# Calculate new type
@@ -287,15 +293,20 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
287293

288294
if all(is_uninhabited(typ) for typ in inner_rest_types):
289295
# All subpatterns always match, so we can apply negative narrowing
290-
new_type, rest_type = self.chk.conditional_types_with_intersection(
291-
current_type, [get_type_range(new_type)], o, default=current_type
292-
)
296+
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
293297
else:
294298
new_inner_type = UninhabitedType()
295299
for typ in new_inner_types:
296300
new_inner_type = join_types(new_inner_type, typ)
297301
new_type = self.construct_sequence_child(current_type, new_inner_type)
298-
if not is_subtype(new_type, current_type):
302+
if is_subtype(new_type, current_type):
303+
new_type, _ = self.chk.conditional_types_with_intersection(
304+
current_type,
305+
[get_type_range(new_type)],
306+
o,
307+
default=current_type
308+
)
309+
else:
299310
new_type = current_type
300311
return PatternType(new_type, rest_type, captures)
301312

@@ -344,8 +355,7 @@ def expand_starred_pattern_types(self,
344355
star_pos: Optional[int],
345356
num_types: int
346357
) -> List[Type]:
347-
"""
348-
Undoes the contraction done by contract_starred_pattern_types.
358+
"""Undoes the contraction done by contract_starred_pattern_types.
349359
350360
For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
351361
to lenght 4 the result is [bool, int, int, str].
@@ -639,6 +649,13 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
639649
For example:
640650
construct_sequence_child(List[int], str) = List[str]
641651
"""
652+
proper_type = get_proper_type(outer_type)
653+
if isinstance(proper_type, UnionType):
654+
types = [
655+
self.construct_sequence_child(item, inner_type) for item in proper_type.items
656+
if self.can_match_sequence(get_proper_type(item))
657+
]
658+
return make_simplified_union(types)
642659
sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
643660
if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
644661
proper_type = get_proper_type(outer_type)
@@ -676,6 +693,11 @@ def get_var(expr: Expression) -> Var:
676693

677694

678695
def get_type_range(typ: Type) -> 'mypy.checker.TypeRange':
696+
typ = get_proper_type(typ)
697+
if (isinstance(typ, Instance)
698+
and typ.last_known_value
699+
and isinstance(typ.last_known_value.value, bool)):
700+
typ = typ.last_known_value
679701
return mypy.checker.TypeRange(typ, is_upper_bound=False)
680702

681703

mypy/patterns.py

+5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:
2121

2222

2323
class AsPattern(Pattern):
24+
"""The pattern <pattern> as <name>"""
2425
# The python ast, and therefore also our ast merges capture, wildcard and as patterns into one
2526
# for easier handling.
2627
# If pattern is None this is a capture pattern. If name and pattern are both none this is a
@@ -39,6 +40,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:
3940

4041

4142
class OrPattern(Pattern):
43+
"""The pattern <pattern> | <pattern> | ..."""
4244
patterns: List[Pattern]
4345

4446
def __init__(self, patterns: List[Pattern]) -> None:
@@ -50,6 +52,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:
5052

5153

5254
class ValuePattern(Pattern):
55+
"""The pattern x.y (or x.y.z, ...)"""
5356
expr: Expression
5457

5558
def __init__(self, expr: Expression):
@@ -73,6 +76,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:
7376

7477

7578
class SequencePattern(Pattern):
79+
"""The pattern [<pattern>, ...]"""
7680
patterns: List[Pattern]
7781

7882
def __init__(self, patterns: List[Pattern]):
@@ -114,6 +118,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:
114118

115119

116120
class ClassPattern(Pattern):
121+
"""The pattern Cls(...)"""
117122
class_ref: RefExpr
118123
positionals: List[Pattern]
119124
keyword_keys: List[str]

0 commit comments

Comments
 (0)