1
1
"""Expression type checker. This file is conceptually part of TypeChecker."""
2
2
3
3
from collections import OrderedDict
4
- from typing import cast , Dict , Set , List , Tuple , Callable , Union , Optional , Iterable , Sequence , Any
4
+ from contextlib import contextmanager
5
+ from typing import (
6
+ cast , Dict , Set , List , Tuple , Callable , Union , Optional , Iterable ,
7
+ Sequence , Any , Iterator
8
+ )
5
9
6
10
from mypy .errors import report_internal_error
7
11
from mypy .typeanal import (
58
62
ArgChecker = Callable [[Type , Type , int , Type , int , int , CallableType , Context , MessageBuilder ],
59
63
None ]
60
64
65
+ # Maximum nesting level for union math in overloads. Setting this to a large value
66
+ # may cause performance issues. The reason is that although the union math algorithm we use
67
+ # nicely captures most corner cases, its worst case complexity is exponential,
68
+ # see https://github.com/python/mypy/pull/5255#discussion_r196896335 for discussion.
69
+ MAX_UNIONS = 5
70
+
71
+
72
+ class TooManyUnions (Exception ):
73
+ """Indicates that we need to stop splitting unions in an attempt
74
+ to match an overload in order to save performance.
75
+ """
76
+
61
77
62
78
def extract_refexpr_names (expr : RefExpr ) -> Set [str ]:
63
79
"""Recursively extracts all module references from a reference expression.
@@ -120,6 +136,11 @@ def __init__(self,
120
136
self .msg = msg
121
137
self .plugin = plugin
122
138
self .type_context = [None ]
139
+ # Temporary overrides for expression types. This is currently
140
+ # used by the union math in overloads.
141
+ # TODO: refactor this to use a pattern similar to the one in
142
+ # multiassign_from_union, or maybe even combine the two?
143
+ self .type_overrides = {} # type: Dict[Expression, Type]
123
144
self .strfrm_checker = StringFormatterChecker (self , self .chk , self .msg )
124
145
125
146
def visit_name_expr (self , e : NameExpr ) -> Type :
@@ -1138,41 +1159,46 @@ def check_overload_call(self,
1138
1159
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
1139
1160
erased_targets = None # type: Optional[List[CallableType]]
1140
1161
unioned_result = None # type: Optional[Tuple[Type, Type]]
1141
- unioned_errors = None # type: Optional[MessageBuilder]
1142
- union_success = False
1143
- if any (isinstance (arg , UnionType ) and len (arg .relevant_items ()) > 1 # "real" union
1144
- for arg in arg_types ):
1145
- erased_targets = self .overload_erased_call_targets (plausible_targets , arg_types ,
1146
- arg_kinds , arg_names , context )
1147
- unioned_callable = self .union_overload_matches (erased_targets )
1148
-
1149
- if unioned_callable is not None :
1150
- unioned_errors = arg_messages .clean_copy ()
1151
- unioned_result = self .check_call (unioned_callable , args , arg_kinds ,
1152
- context , arg_names ,
1153
- arg_messages = unioned_errors ,
1154
- callable_name = callable_name ,
1155
- object_type = object_type )
1162
+ union_interrupted = False # did we try all union combinations?
1163
+ if any (self .real_union (arg ) for arg in arg_types ):
1164
+ unioned_errors = arg_messages .clean_copy ()
1165
+ try :
1166
+ unioned_return = self .union_overload_result (plausible_targets , args ,
1167
+ arg_types , arg_kinds , arg_names ,
1168
+ callable_name , object_type ,
1169
+ context ,
1170
+ arg_messages = unioned_errors )
1171
+ except TooManyUnions :
1172
+ union_interrupted = True
1173
+ else :
1156
1174
# Record if we succeeded. Next we need to see if maybe normal procedure
1157
1175
# gives a narrower type.
1158
- union_success = unioned_result is not None and not unioned_errors .is_errors ()
1176
+ if unioned_return :
1177
+ returns , inferred_types = zip (* unioned_return )
1178
+ # Note that we use `union_overload_matches` instead of just returning
1179
+ # a union of inferred callables because for example a call
1180
+ # Union[int -> int, str -> str](Union[int, str]) is invalid and
1181
+ # we don't want to introduce internal inconsistencies.
1182
+ unioned_result = (UnionType .make_simplified_union (list (returns ),
1183
+ context .line ,
1184
+ context .column ),
1185
+ self .union_overload_matches (inferred_types ))
1159
1186
1160
1187
# Step 3: We try checking each branch one-by-one.
1161
1188
inferred_result = self .infer_overload_return_type (plausible_targets , args , arg_types ,
1162
1189
arg_kinds , arg_names , callable_name ,
1163
1190
object_type , context , arg_messages )
1164
- if inferred_result is not None :
1165
- # Success! Stop early by returning the best among normal and unioned.
1166
- if not union_success :
1191
+ # If any of the checks succeed, stop early.
1192
+ if inferred_result is not None and unioned_result is not None :
1193
+ # Both unioned and direct checks succeeded, choose the more precise type.
1194
+ if (is_subtype (inferred_result [0 ], unioned_result [0 ]) and
1195
+ not isinstance (inferred_result [0 ], AnyType )):
1167
1196
return inferred_result
1168
- else :
1169
- assert unioned_result is not None
1170
- if is_subtype (inferred_result [0 ], unioned_result [0 ]):
1171
- return inferred_result
1172
- return unioned_result
1173
- elif union_success :
1174
- assert unioned_result is not None
1175
1197
return unioned_result
1198
+ elif unioned_result is not None :
1199
+ return unioned_result
1200
+ elif inferred_result is not None :
1201
+ return inferred_result
1176
1202
1177
1203
# Step 4: Failure. At this point, we know there is no match. We fall back to trying
1178
1204
# to find a somewhat plausible overload target using the erased types
@@ -1183,19 +1209,12 @@ def check_overload_call(self,
1183
1209
#
1184
1210
# Neither alternative matches, but we can guess the user probably wants the
1185
1211
# second one.
1186
- if erased_targets is None :
1187
- erased_targets = self .overload_erased_call_targets (plausible_targets , arg_types ,
1188
- arg_kinds , arg_names , context )
1212
+ erased_targets = self .overload_erased_call_targets (plausible_targets , arg_types ,
1213
+ arg_kinds , arg_names , context )
1189
1214
1190
1215
# Step 5: We try and infer a second-best alternative if possible. If not, fall back
1191
1216
# to using 'Any'.
1192
- if unioned_result is not None :
1193
- # When possible, return the error messages generated from the union-math attempt:
1194
- # they tend to be a little nicer.
1195
- assert unioned_errors is not None
1196
- arg_messages .add_errors (unioned_errors )
1197
- return unioned_result
1198
- elif len (erased_targets ) > 0 :
1217
+ if len (erased_targets ) > 0 :
1199
1218
# Pick the first plausible erased target as the fallback
1200
1219
# TODO: Adjust the error message here to make it clear there was no match.
1201
1220
target = erased_targets [0 ] # type: Type
@@ -1204,11 +1223,14 @@ def check_overload_call(self,
1204
1223
if not self .chk .should_suppress_optional_error (arg_types ):
1205
1224
arg_messages .no_variant_matches_arguments (callee , arg_types , context )
1206
1225
target = AnyType (TypeOfAny .from_error )
1207
-
1208
- return self .check_call (target , args , arg_kinds , context , arg_names ,
1209
- arg_messages = arg_messages ,
1210
- callable_name = callable_name ,
1211
- object_type = object_type )
1226
+ result = self .check_call (target , args , arg_kinds , context , arg_names ,
1227
+ arg_messages = arg_messages ,
1228
+ callable_name = callable_name ,
1229
+ object_type = object_type )
1230
+ if union_interrupted :
1231
+ self .chk .msg .note ("Not all union combinations were tried"
1232
+ " because there are too many unions" , context )
1233
+ return result
1212
1234
1213
1235
def plausible_overload_call_targets (self ,
1214
1236
arg_types : List [Type ],
@@ -1358,18 +1380,110 @@ def overload_erased_call_targets(self,
1358
1380
matches .append (typ )
1359
1381
return matches
1360
1382
1361
- def union_overload_matches (self , callables : List [CallableType ]) -> Optional [CallableType ]:
1383
+ def union_overload_result (self ,
1384
+ plausible_targets : List [CallableType ],
1385
+ args : List [Expression ],
1386
+ arg_types : List [Type ],
1387
+ arg_kinds : List [int ],
1388
+ arg_names : Optional [Sequence [Optional [str ]]],
1389
+ callable_name : Optional [str ],
1390
+ object_type : Optional [Type ],
1391
+ context : Context ,
1392
+ arg_messages : Optional [MessageBuilder ] = None ,
1393
+ level : int = 0
1394
+ ) -> Optional [List [Tuple [Type , Type ]]]:
1395
+ """Accepts a list of overload signatures and attempts to match calls by destructuring
1396
+ the first union.
1397
+
1398
+ Return a list of (<return type>, <inferred variant type>) if the call succeeds for every
1399
+ item of the destructured union. Returns None if there is no match.
1400
+ """
1401
+ # Step 1: If we are already too deep, then stop immediately. Otherwise mypy might
1402
+ # hang for a long time because of a weird overload call. The caller will get
1403
+ # the exception and generate an appropriate note message, if needed.
1404
+ if level >= MAX_UNIONS :
1405
+ raise TooManyUnions
1406
+
1407
+ # Step 2: Find position of the first union in arguments. Return the normal inferred
1408
+ # type if no more unions left.
1409
+ for idx , typ in enumerate (arg_types ):
1410
+ if self .real_union (typ ):
1411
+ break
1412
+ else :
1413
+ # No unions in args, just fall back to normal inference
1414
+ with self .type_overrides_set (args , arg_types ):
1415
+ res = self .infer_overload_return_type (plausible_targets , args , arg_types ,
1416
+ arg_kinds , arg_names , callable_name ,
1417
+ object_type , context , arg_messages )
1418
+ if res is not None :
1419
+ return [res ]
1420
+ return None
1421
+
1422
+ # Step 3: Try a direct match before splitting to avoid unnecessary union splits
1423
+ # and save performance.
1424
+ with self .type_overrides_set (args , arg_types ):
1425
+ direct = self .infer_overload_return_type (plausible_targets , args , arg_types ,
1426
+ arg_kinds , arg_names , callable_name ,
1427
+ object_type , context , arg_messages )
1428
+ if direct is not None and not isinstance (direct [0 ], (UnionType , AnyType )):
1429
+ # We only return early for non-union (and non-Any) results, to avoid a greedy match.
1430
+ return [direct ]
1431
+
1432
+ # Step 4: Split the first remaining union type in arguments into items and
1433
+ # try to match each item individually (recursive).
1434
+ first_union = arg_types [idx ]
1435
+ assert isinstance (first_union , UnionType )
1436
+ res_items = []
1437
+ for item in first_union .relevant_items ():
1438
+ new_arg_types = arg_types .copy ()
1439
+ new_arg_types [idx ] = item
1440
+ sub_result = self .union_overload_result (plausible_targets , args , new_arg_types ,
1441
+ arg_kinds , arg_names , callable_name ,
1442
+ object_type , context , arg_messages ,
1443
+ level + 1 )
1444
+ if sub_result is not None :
1445
+ res_items .extend (sub_result )
1446
+ else :
1447
+ # Some item doesn't match, return early.
1448
+ return None
1449
+
1450
+ # Step 5: If splitting succeeded, then filter out duplicate items before returning.
1451
+ seen = set () # type: Set[Tuple[Type, Type]]
1452
+ result = []
1453
+ for pair in res_items :
1454
+ if pair not in seen :
1455
+ seen .add (pair )
1456
+ result .append (pair )
1457
+ return result
1458
+
1459
+ def real_union (self , typ : Type ) -> bool :
1460
+ return isinstance (typ , UnionType ) and len (typ .relevant_items ()) > 1
1461
+
1462
+ @contextmanager
1463
+ def type_overrides_set (self , exprs : Sequence [Expression ],
1464
+ overrides : Sequence [Type ]) -> Iterator [None ]:
1465
+ """Set _temporary_ type overrides for given expressions."""
1466
+ assert len (exprs ) == len (overrides )
1467
+ for expr , typ in zip (exprs , overrides ):
1468
+ self .type_overrides [expr ] = typ
1469
+ try :
1470
+ yield
1471
+ finally :
1472
+ for expr in exprs :
1473
+ del self .type_overrides [expr ]
1474
+
1475
+ def union_overload_matches (self , types : Sequence [Type ]) -> Union [AnyType , CallableType ]:
1362
1476
"""Accepts a list of overload signatures and attempts to combine them together into a
1363
1477
new CallableType consisting of the union of all of the given arguments and return types.
1364
1478
1365
- Returns None if it is not possible to combine the different callables together in a
1366
- sound manner.
1367
-
1368
- Assumes all of the given callables have argument counts compatible with the caller.
1479
+ If there is at least one non-callable type, return Any (this can happen if there is
1480
+ an ambiguity because of Any in arguments).
1369
1481
"""
1370
- if len (callables ) == 0 :
1371
- return None
1372
- elif len (callables ) == 1 :
1482
+ assert types , "Trying to merge no callables"
1483
+ if not all (isinstance (c , CallableType ) for c in types ):
1484
+ return AnyType (TypeOfAny .special_form )
1485
+ callables = cast (List [CallableType ], types )
1486
+ if len (callables ) == 1 :
1373
1487
return callables [0 ]
1374
1488
1375
1489
# Note: we are assuming here that if a user uses some TypeVar 'T' in
@@ -1389,58 +1503,52 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
1389
1503
new_kinds = list (callables [0 ].arg_kinds )
1390
1504
new_returns = [] # type: List[Type]
1391
1505
1506
+ too_complex = False
1392
1507
for target in callables :
1393
- # We conservatively end if the overloads do not have the exact same signature.
1394
- # The only exception is if one arg is optional and the other is positional: in that
1395
- # case, we continue unioning (and expect a positional arg).
1396
- # TODO: Enhance the union overload logic to handle a wider variety of signatures.
1508
+ # We fall back to Callable[..., Union[<returns>]] if the overloads do not have
1509
+ # the exact same signature. The only exception is if one arg is optional and
1510
+ # the other is positional: in that case, we continue unioning (and expect a
1511
+ # positional arg).
1512
+ # TODO: Enhance the merging logic to handle a wider variety of signatures.
1397
1513
if len (new_kinds ) != len (target .arg_kinds ):
1398
- return None
1514
+ too_complex = True
1515
+ break
1399
1516
for i , (new_kind , target_kind ) in enumerate (zip (new_kinds , target .arg_kinds )):
1400
1517
if new_kind == target_kind :
1401
1518
continue
1402
1519
elif new_kind in (ARG_POS , ARG_OPT ) and target_kind in (ARG_POS , ARG_OPT ):
1403
1520
new_kinds [i ] = ARG_POS
1404
1521
else :
1405
- return None
1522
+ too_complex = True
1523
+ break
1524
+
1525
+ if too_complex :
1526
+ break # outer loop
1406
1527
1407
1528
for i , arg in enumerate (target .arg_types ):
1408
1529
new_args [i ].append (arg )
1409
1530
new_returns .append (target .ret_type )
1410
1531
1411
- union_count = 0
1532
+ union_return = UnionType .make_simplified_union (new_returns )
1533
+ if too_complex :
1534
+ any = AnyType (TypeOfAny .special_form )
1535
+ return callables [0 ].copy_modified (
1536
+ arg_types = [any , any ],
1537
+ arg_kinds = [ARG_STAR , ARG_STAR2 ],
1538
+ arg_names = [None , None ],
1539
+ ret_type = union_return ,
1540
+ variables = variables ,
1541
+ implicit = True )
1542
+
1412
1543
final_args = []
1413
1544
for args_list in new_args :
1414
1545
new_type = UnionType .make_simplified_union (args_list )
1415
- union_count += 1 if isinstance (new_type , UnionType ) else 0
1416
1546
final_args .append (new_type )
1417
1547
1418
- # TODO: Modify this check to be less conservative.
1419
- #
1420
- # Currently, we permit only one union in the arguments because if we allow
1421
- # multiple, we can't always guarantee the synthesized callable will be correct.
1422
- #
1423
- # For example, suppose we had the following two overloads:
1424
- #
1425
- # @overload
1426
- # def f(x: A, y: B) -> None: ...
1427
- # @overload
1428
- # def f(x: B, y: A) -> None: ...
1429
- #
1430
- # If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...",
1431
- # then we'd incorrectly accept calls like "f(A(), A())" when they really ought to
1432
- # be rejected.
1433
- #
1434
- # However, that means we'll also give up if the original overloads contained
1435
- # any unions. This is likely unnecessary -- we only really need to give up if
1436
- # there are more then one *synthesized* union arguments.
1437
- if union_count >= 2 :
1438
- return None
1439
-
1440
1548
return callables [0 ].copy_modified (
1441
1549
arg_types = final_args ,
1442
1550
arg_kinds = new_kinds ,
1443
- ret_type = UnionType . make_simplified_union ( new_returns ) ,
1551
+ ret_type = union_return ,
1444
1552
variables = variables ,
1445
1553
implicit = True )
1446
1554
@@ -2733,6 +2841,8 @@ def accept(self,
2733
2841
is True and this expression is a call, allow it to return None. This
2734
2842
applies only to this expression and not any subexpressions.
2735
2843
"""
2844
+ if node in self .type_overrides :
2845
+ return self .type_overrides [node ]
2736
2846
self .type_context .append (type_context )
2737
2847
try :
2738
2848
if allow_none_return and isinstance (node , CallExpr ):
0 commit comments