Skip to content

Commit 0ca6bf9

Browse files
authored
Alternative algorithm for union math (#5255)
Fixes #5243. Fixes #5249.

Some comments:
* I went ahead with a slow but very simple recursive algorithm that treats all the various complex cases correctly. On the one hand it can be exponential, but on the other hand the complexity will be bad _only_ if the user abuses lots of unions.
* I use a hack caused by the fact that currently most function inference functions pass argument _expressions_ instead of types; I left a TODO to use a more unified approach similar to multiassign_from_union.
* It may look like there are many changes in tests, but actually there are not; the differences are because:
  - Error messages now show the _first potentially matching_ overload (which is OK, I think).
  - The order of items in many unions was reversed; apparently union `__repr__` is unstable.
1 parent 44e789d commit 0ca6bf9

File tree

3 files changed

+326
-109
lines changed

3 files changed

+326
-109
lines changed

mypy/checkexpr.py

+191-81
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Expression type checker. This file is conceptually part of TypeChecker."""
22

33
from collections import OrderedDict
4-
from typing import cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable, Sequence, Any
4+
from contextlib import contextmanager
5+
from typing import (
6+
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable,
7+
Sequence, Any, Iterator
8+
)
59

610
from mypy.errors import report_internal_error
711
from mypy.typeanal import (
@@ -58,6 +62,18 @@
5862
ArgChecker = Callable[[Type, Type, int, Type, int, int, CallableType, Context, MessageBuilder],
5963
None]
6064

65+
# Maximum nesting level for union math in overloads. Setting this to a large value
# may cause performance issues: although the union math algorithm we use nicely
# captures most corner cases, its worst-case complexity is exponential; see
# https://github.com/python/mypy/pull/5255#discussion_r196896335 for discussion.
MAX_UNIONS = 5
70+
71+
72+
class TooManyUnions(Exception):
    """Indicates that we need to stop splitting unions in an attempt
    to match an overload, in order to save performance.

    Raised by union_overload_result when its nesting level reaches MAX_UNIONS;
    check_overload_call catches it and records that the union search was
    interrupted.
    """
76+
6177

6278
def extract_refexpr_names(expr: RefExpr) -> Set[str]:
6379
"""Recursively extracts all module references from a reference expression.
@@ -120,6 +136,11 @@ def __init__(self,
120136
self.msg = msg
121137
self.plugin = plugin
122138
self.type_context = [None]
139+
# Temporary overrides for expression types. This is currently
140+
# used by the union math in overloads.
141+
# TODO: refactor this to use a pattern similar to one in
142+
# multiassign_from_union, or maybe even combine the two?
143+
self.type_overrides = {} # type: Dict[Expression, Type]
123144
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)
124145

125146
def visit_name_expr(self, e: NameExpr) -> Type:
@@ -1138,41 +1159,46 @@ def check_overload_call(self,
11381159
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
11391160
erased_targets = None # type: Optional[List[CallableType]]
11401161
unioned_result = None # type: Optional[Tuple[Type, Type]]
1141-
unioned_errors = None # type: Optional[MessageBuilder]
1142-
union_success = False
1143-
if any(isinstance(arg, UnionType) and len(arg.relevant_items()) > 1 # "real" union
1144-
for arg in arg_types):
1145-
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
1146-
arg_kinds, arg_names, context)
1147-
unioned_callable = self.union_overload_matches(erased_targets)
1148-
1149-
if unioned_callable is not None:
1150-
unioned_errors = arg_messages.clean_copy()
1151-
unioned_result = self.check_call(unioned_callable, args, arg_kinds,
1152-
context, arg_names,
1153-
arg_messages=unioned_errors,
1154-
callable_name=callable_name,
1155-
object_type=object_type)
1162+
union_interrupted = False # did we try all union combinations?
1163+
if any(self.real_union(arg) for arg in arg_types):
1164+
unioned_errors = arg_messages.clean_copy()
1165+
try:
1166+
unioned_return = self.union_overload_result(plausible_targets, args,
1167+
arg_types, arg_kinds, arg_names,
1168+
callable_name, object_type,
1169+
context,
1170+
arg_messages=unioned_errors)
1171+
except TooManyUnions:
1172+
union_interrupted = True
1173+
else:
11561174
# Record if we succeeded. Next we need to see if maybe normal procedure
11571175
# gives a narrower type.
1158-
union_success = unioned_result is not None and not unioned_errors.is_errors()
1176+
if unioned_return:
1177+
returns, inferred_types = zip(*unioned_return)
1178+
# Note that we use `union_overload_matches` instead of just returning
1179+
# a union of inferred callables because for example a call
1180+
# Union[int -> int, str -> str](Union[int, str]) is invalid and
1181+
# we don't want to introduce internal inconsistencies.
1182+
unioned_result = (UnionType.make_simplified_union(list(returns),
1183+
context.line,
1184+
context.column),
1185+
self.union_overload_matches(inferred_types))
11591186

11601187
# Step 3: We try checking each branch one-by-one.
11611188
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
11621189
arg_kinds, arg_names, callable_name,
11631190
object_type, context, arg_messages)
1164-
if inferred_result is not None:
1165-
# Success! Stop early by returning the best among normal and unioned.
1166-
if not union_success:
1191+
# If any of checks succeed, stop early.
1192+
if inferred_result is not None and unioned_result is not None:
1193+
# Both unioned and direct checks succeeded, choose the more precise type.
1194+
if (is_subtype(inferred_result[0], unioned_result[0]) and
1195+
not isinstance(inferred_result[0], AnyType)):
11671196
return inferred_result
1168-
else:
1169-
assert unioned_result is not None
1170-
if is_subtype(inferred_result[0], unioned_result[0]):
1171-
return inferred_result
1172-
return unioned_result
1173-
elif union_success:
1174-
assert unioned_result is not None
11751197
return unioned_result
1198+
elif unioned_result is not None:
1199+
return unioned_result
1200+
elif inferred_result is not None:
1201+
return inferred_result
11761202

11771203
# Step 4: Failure. At this point, we know there is no match. We fall back to trying
11781204
# to find a somewhat plausible overload target using the erased types
@@ -1183,19 +1209,12 @@ def check_overload_call(self,
11831209
#
11841210
# Neither alternative matches, but we can guess the user probably wants the
11851211
# second one.
1186-
if erased_targets is None:
1187-
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
1188-
arg_kinds, arg_names, context)
1212+
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
1213+
arg_kinds, arg_names, context)
11891214

11901215
# Step 5: We try and infer a second-best alternative if possible. If not, fall back
11911216
# to using 'Any'.
1192-
if unioned_result is not None:
1193-
# When possible, return the error messages generated from the union-math attempt:
1194-
# they tend to be a little nicer.
1195-
assert unioned_errors is not None
1196-
arg_messages.add_errors(unioned_errors)
1197-
return unioned_result
1198-
elif len(erased_targets) > 0:
1217+
if len(erased_targets) > 0:
11991218
# Pick the first plausible erased target as the fallback
12001219
# TODO: Adjust the error message here to make it clear there was no match.
12011220
target = erased_targets[0] # type: Type
@@ -1204,11 +1223,14 @@ def check_overload_call(self,
12041223
if not self.chk.should_suppress_optional_error(arg_types):
12051224
arg_messages.no_variant_matches_arguments(callee, arg_types, context)
12061225
target = AnyType(TypeOfAny.from_error)
1207-
1208-
return self.check_call(target, args, arg_kinds, context, arg_names,
1209-
arg_messages=arg_messages,
1210-
callable_name=callable_name,
1211-
object_type=object_type)
1226+
result = self.check_call(target, args, arg_kinds, context, arg_names,
1227+
arg_messages=arg_messages,
1228+
callable_name=callable_name,
1229+
object_type=object_type)
1230+
if union_interrupted:
1231+
self.chk.msg.note("Not all union combinations were tried"
1232+
" because there are too many unions", context)
1233+
return result
12121234

12131235
def plausible_overload_call_targets(self,
12141236
arg_types: List[Type],
@@ -1358,18 +1380,110 @@ def overload_erased_call_targets(self,
13581380
matches.append(typ)
13591381
return matches
13601382

1361-
def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]:
1383+
def union_overload_result(self,
                          plausible_targets: List[CallableType],
                          args: List[Expression],
                          arg_types: List[Type],
                          arg_kinds: List[int],
                          arg_names: Optional[Sequence[Optional[str]]],
                          callable_name: Optional[str],
                          object_type: Optional[Type],
                          context: Context,
                          arg_messages: Optional[MessageBuilder] = None,
                          level: int = 0
                          ) -> Optional[List[Tuple[Type, Type]]]:
    """Try to match an overloaded call by destructuring the first union argument.

    The first argument type for which real_union() holds is replaced by each of
    its relevant items in turn, and the method recurses (with level + 1) on the
    resulting argument types.

    Return a list of (<return type>, <inferred variant type>) pairs, without
    duplicates, if the call succeeds for every item of the destructured union.
    Return None if there is no match.

    Raise TooManyUnions if level has reached MAX_UNIONS.
    """
    # Step 1: Bail out immediately when nested too deep, otherwise mypy might hang
    #         for a long time on a weird overload call. The caller receives the
    #         exception and generates an appropriate note, if needed.
    if level >= MAX_UNIONS:
        raise TooManyUnions

    # Step 2: Locate the first "real" union among the argument types (None if absent).
    union_pos = next((pos for pos, candidate in enumerate(arg_types)
                      if self.real_union(candidate)), None)

    # Step 3: Attempt a direct match with the current argument types. This is the
    #         final answer when no unions are left, and otherwise lets us avoid
    #         unnecessary union splits and save performance.
    with self.type_overrides_set(args, arg_types):
        direct = self.infer_overload_return_type(plausible_targets, args, arg_types,
                                                 arg_kinds, arg_names, callable_name,
                                                 object_type, context, arg_messages)
    if union_pos is None:
        # No unions in args: the normal inference result is all there is.
        return None if direct is None else [direct]
    if direct is not None and not isinstance(direct[0], (UnionType, AnyType)):
        # Only non-union results are returned early, to avoid a greedy match.
        return [direct]

    # Step 4: Split the first remaining union into its items and match each item
    #         individually (recursively).
    union_arg = arg_types[union_pos]
    assert isinstance(union_arg, UnionType)
    collected = []  # type: List[Tuple[Type, Type]]
    for variant in union_arg.relevant_items():
        narrowed = list(arg_types)
        narrowed[union_pos] = variant
        partial = self.union_overload_result(plausible_targets, args, narrowed,
                                             arg_kinds, arg_names, callable_name,
                                             object_type, context, arg_messages,
                                             level + 1)
        if partial is None:
            # Some item doesn't match, so the whole split fails.
            return None
        collected.extend(partial)

    # Step 5: Splitting succeeded; drop duplicate pairs while keeping first-seen order.
    return list(OrderedDict.fromkeys(collected))
1458+
1459+
def real_union(self, typ: Type) -> bool:
    """Return True if typ is a UnionType with more than one relevant item."""
    if not isinstance(typ, UnionType):
        return False
    return len(typ.relevant_items()) > 1
1461+
1462+
@contextmanager
def type_overrides_set(self, exprs: Sequence[Expression],
                       overrides: Sequence[Type]) -> Iterator[None]:
    """Set _temporary_ type overrides for given expressions.

    While the context is active, self.type_overrides maps each expression in
    exprs to the type at the same position in overrides. On exit (normal or via
    an exception) every touched entry is put back to the state it had on entry:
    entries that did not exist are removed, and entries that already held an
    override get their previous value back.

    exprs and overrides must have the same length.
    """
    assert len(exprs) == len(overrides)
    missing = object()  # marks "no override was set before entering"
    # Snapshot the prior state once per distinct expression, so that a repeated
    # expression or an enclosing override does not break the cleanup.
    saved = {}  # type: Dict[Expression, object]
    for expr in exprs:
        if expr not in saved:
            saved[expr] = self.type_overrides.get(expr, missing)
    for expr, typ in zip(exprs, overrides):
        self.type_overrides[expr] = typ
    try:
        yield
    finally:
        for expr, previous in saved.items():
            if previous is missing:
                self.type_overrides.pop(expr, None)
            else:
                self.type_overrides[expr] = previous
1474+
1475+
def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
13621476
"""Accepts a list of overload signatures and attempts to combine them together into a
13631477
new CallableType consisting of the union of all of the given arguments and return types.
13641478
1365-
Returns None if it is not possible to combine the different callables together in a
1366-
sound manner.
1367-
1368-
Assumes all of the given callables have argument counts compatible with the caller.
1479+
If there is at least one non-callable type, return Any (this can happen if there is
1480+
an ambiguity because of Any in arguments).
13691481
"""
1370-
if len(callables) == 0:
1371-
return None
1372-
elif len(callables) == 1:
1482+
assert types, "Trying to merge no callables"
1483+
if not all(isinstance(c, CallableType) for c in types):
1484+
return AnyType(TypeOfAny.special_form)
1485+
callables = cast(List[CallableType], types)
1486+
if len(callables) == 1:
13731487
return callables[0]
13741488

13751489
# Note: we are assuming here that if a user uses some TypeVar 'T' in
@@ -1389,58 +1503,52 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
13891503
new_kinds = list(callables[0].arg_kinds)
13901504
new_returns = [] # type: List[Type]
13911505

1506+
too_complex = False
13921507
for target in callables:
1393-
# We conservatively end if the overloads do not have the exact same signature.
1394-
# The only exception is if one arg is optional and the other is positional: in that
1395-
# case, we continue unioning (and expect a positional arg).
1396-
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
1508+
# We fall back to Callable[..., Union[<returns>]] if the overloads do not have
1509+
# the exact same signature. The only exception is if one arg is optional and
1510+
# the other is positional: in that case, we continue unioning (and expect a
1511+
# positional arg).
1512+
# TODO: Enhance the merging logic to handle a wider variety of signatures.
13971513
if len(new_kinds) != len(target.arg_kinds):
1398-
return None
1514+
too_complex = True
1515+
break
13991516
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)):
14001517
if new_kind == target_kind:
14011518
continue
14021519
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT):
14031520
new_kinds[i] = ARG_POS
14041521
else:
1405-
return None
1522+
too_complex = True
1523+
break
1524+
1525+
if too_complex:
1526+
break # outer loop
14061527

14071528
for i, arg in enumerate(target.arg_types):
14081529
new_args[i].append(arg)
14091530
new_returns.append(target.ret_type)
14101531

1411-
union_count = 0
1532+
union_return = UnionType.make_simplified_union(new_returns)
1533+
if too_complex:
1534+
any = AnyType(TypeOfAny.special_form)
1535+
return callables[0].copy_modified(
1536+
arg_types=[any, any],
1537+
arg_kinds=[ARG_STAR, ARG_STAR2],
1538+
arg_names=[None, None],
1539+
ret_type=union_return,
1540+
variables=variables,
1541+
implicit=True)
1542+
14121543
final_args = []
14131544
for args_list in new_args:
14141545
new_type = UnionType.make_simplified_union(args_list)
1415-
union_count += 1 if isinstance(new_type, UnionType) else 0
14161546
final_args.append(new_type)
14171547

1418-
# TODO: Modify this check to be less conservative.
1419-
#
1420-
# Currently, we permit only one union in the arguments because if we allow
1421-
# multiple, we can't always guarantee the synthesized callable will be correct.
1422-
#
1423-
# For example, suppose we had the following two overloads:
1424-
#
1425-
# @overload
1426-
# def f(x: A, y: B) -> None: ...
1427-
# @overload
1428-
# def f(x: B, y: A) -> None: ...
1429-
#
1430-
# If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...",
1431-
# then we'd incorrectly accept calls like "f(A(), A())" when they really ought to
1432-
# be rejected.
1433-
#
1434-
# However, that means we'll also give up if the original overloads contained
1435-
# any unions. This is likely unnecessary -- we only really need to give up if
1436-
# there are more then one *synthesized* union arguments.
1437-
if union_count >= 2:
1438-
return None
1439-
14401548
return callables[0].copy_modified(
14411549
arg_types=final_args,
14421550
arg_kinds=new_kinds,
1443-
ret_type=UnionType.make_simplified_union(new_returns),
1551+
ret_type=union_return,
14441552
variables=variables,
14451553
implicit=True)
14461554

@@ -2733,6 +2841,8 @@ def accept(self,
27332841
is True and this expression is a call, allow it to return None. This
27342842
applies only to this expression and not any subexpressions.
27352843
"""
2844+
if node in self.type_overrides:
2845+
return self.type_overrides[node]
27362846
self.type_context.append(type_context)
27372847
try:
27382848
if allow_none_return and isinstance(node, CallExpr):

0 commit comments

Comments
 (0)