@@ -4089,36 +4089,57 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
4089
4089
if isinstance (subject_type , DeletedType ):
4090
4090
self .msg .deleted_as_rvalue (subject_type , s )
4091
4091
4092
+ # We infer types of patterns twice. The first pass is used
4093
+ # to infer the types of capture variables. The type of a
4094
+ # capture variable may depend on multiple patterns (it
4095
+ # will be a union of all capture types). This pass ignores
4096
+ # guard expressions.
4092
4097
pattern_types = [self .pattern_checker .accept (p , subject_type ) for p in s .patterns ]
4093
-
4094
4098
type_maps : List [TypeMap ] = [t .captures for t in pattern_types ]
4095
- self .infer_variable_types_from_type_maps (type_maps )
4099
+ inferred_types = self .infer_variable_types_from_type_maps (type_maps )
4096
4100
4097
- for pattern_type , g , b in zip (pattern_types , s .guards , s .bodies ):
4101
+ # The second pass narrows down the types and type checks bodies.
4102
+ for p , g , b in zip (s .patterns , s .guards , s .bodies ):
4103
+ current_subject_type = self .expr_checker .narrow_type_from_binder (s .subject ,
4104
+ subject_type )
4105
+ pattern_type = self .pattern_checker .accept (p , current_subject_type )
4098
4106
with self .binder .frame_context (can_skip = True , fall_through = 2 ):
4099
4107
if b .is_unreachable or isinstance (get_proper_type (pattern_type .type ),
4100
4108
UninhabitedType ):
4101
4109
self .push_type_map (None )
4110
+ else_map : TypeMap = {}
4102
4111
else :
4103
- self .binder .put (s .subject , pattern_type .type )
4112
+ pattern_map , else_map = conditional_types_to_typemaps (
4113
+ s .subject ,
4114
+ pattern_type .type ,
4115
+ pattern_type .rest_type
4116
+ )
4117
+ self .remove_capture_conflicts (pattern_type .captures ,
4118
+ inferred_types )
4119
+ self .push_type_map (pattern_map )
4104
4120
self .push_type_map (pattern_type .captures )
4105
4121
if g is not None :
4106
- gt = get_proper_type (self .expr_checker .accept (g ))
4122
+ with self .binder .frame_context (can_skip = True , fall_through = 3 ):
4123
+ gt = get_proper_type (self .expr_checker .accept (g ))
4107
4124
4108
- if isinstance (gt , DeletedType ):
4109
- self .msg .deleted_as_rvalue (gt , s )
4125
+ if isinstance (gt , DeletedType ):
4126
+ self .msg .deleted_as_rvalue (gt , s )
4110
4127
4111
- if_map , _ = self .find_isinstance_check (g )
4128
+ guard_map , guard_else_map = self .find_isinstance_check (g )
4129
+ else_map = or_conditional_maps (else_map , guard_else_map )
4112
4130
4113
- self .push_type_map (if_map )
4114
- self .accept (b )
4131
+ self .push_type_map (guard_map )
4132
+ self .accept (b )
4133
+ else :
4134
+ self .accept (b )
4135
+ self .push_type_map (else_map )
4115
4136
4116
4137
# This is needed due to a quirk in frame_context. Without it types will stay narrowed
4117
4138
# after the match.
4118
4139
with self .binder .frame_context (can_skip = False , fall_through = 2 ):
4119
4140
pass
4120
4141
4121
- def infer_variable_types_from_type_maps (self , type_maps : List [TypeMap ]) -> None :
4142
+ def infer_variable_types_from_type_maps (self , type_maps : List [TypeMap ]) -> Dict [ Var , Type ] :
4122
4143
all_captures : Dict [Var , List [Tuple [NameExpr , Type ]]] = defaultdict (list )
4123
4144
for tm in type_maps :
4124
4145
if tm is not None :
@@ -4128,28 +4149,38 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None:
4128
4149
assert isinstance (node , Var )
4129
4150
all_captures [node ].append ((expr , typ ))
4130
4151
4152
+ inferred_types : Dict [Var , Type ] = {}
4131
4153
for var , captures in all_captures .items ():
4132
- conflict = False
4154
+ already_exists = False
4133
4155
types : List [Type ] = []
4134
4156
for expr , typ in captures :
4135
4157
types .append (typ )
4136
4158
4137
- previous_type , _ , inferred = self .check_lvalue (expr )
4159
+ previous_type , _ , _ = self .check_lvalue (expr )
4138
4160
if previous_type is not None :
4139
- conflict = True
4140
- self .check_subtype (typ , previous_type , expr ,
4141
- msg = message_registry .INCOMPATIBLE_TYPES_IN_CAPTURE ,
4142
- subtype_label = "pattern captures type" ,
4143
- supertype_label = "variable has type" )
4144
- for type_map in type_maps :
4145
- if type_map is not None and expr in type_map :
4146
- del type_map [expr ]
4147
-
4148
- if not conflict :
4161
+ already_exists = True
4162
+ if self .check_subtype (typ , previous_type , expr ,
4163
+ msg = message_registry .INCOMPATIBLE_TYPES_IN_CAPTURE ,
4164
+ subtype_label = "pattern captures type" ,
4165
+ supertype_label = "variable has type" ):
4166
+ inferred_types [var ] = previous_type
4167
+
4168
+ if not already_exists :
4149
4169
new_type = UnionType .make_union (types )
4150
4170
# Infer the union type at the first occurrence
4151
4171
first_occurrence , _ = captures [0 ]
4172
+ inferred_types [var ] = new_type
4152
4173
self .infer_variable_type (var , first_occurrence , new_type , first_occurrence )
4174
+ return inferred_types
4175
+
4176
+ def remove_capture_conflicts (self , type_map : TypeMap , inferred_types : Dict [Var , Type ]) -> None :
4177
+ if type_map :
4178
+ for expr , typ in list (type_map .items ()):
4179
+ if isinstance (expr , NameExpr ):
4180
+ node = expr .node
4181
+ assert isinstance (node , Var )
4182
+ if node not in inferred_types or not is_subtype (typ , inferred_types [node ]):
4183
+ del type_map [expr ]
4153
4184
4154
4185
def make_fake_typeinfo (self ,
4155
4186
curr_module_fullname : str ,
@@ -5637,6 +5668,14 @@ def conditional_types(current_type: Type,
5637
5668
None means no new information can be inferred. If default is set it is returned
5638
5669
instead."""
5639
5670
if proposed_type_ranges :
5671
+ if len (proposed_type_ranges ) == 1 :
5672
+ target = proposed_type_ranges [0 ].item
5673
+ target = get_proper_type (target )
5674
+ if isinstance (target , LiteralType ) and (target .is_enum_literal ()
5675
+ or isinstance (target .value , bool )):
5676
+ enum_name = target .fallback .type .fullname
5677
+ current_type = try_expanding_sum_type_to_union (current_type ,
5678
+ enum_name )
5640
5679
proposed_items = [type_range .item for type_range in proposed_type_ranges ]
5641
5680
proposed_type = make_simplified_union (proposed_items )
5642
5681
if isinstance (proposed_type , AnyType ):
0 commit comments