Skip to content

Commit e66d53b

Browse files
authored
Relax union math logic to allow signatures with different arg kinds (#5222)
Resolves #5204. This commit relaxes the union math logic so that it allows signatures with arg kinds that are nearly identical except that one arg is positional and the other is optional. This commit also removes the "names must be the same" restriction mostly so that the original example given in #4576 will pass. (In retrospect, I think this check didn't really buy us much -- even if the alternatives share the same arg names, there's no guarantee the actual implementation signature will also share the same.)
1 parent 4fe2220 commit e66d53b

File tree

2 files changed

+64
-20
lines changed

2 files changed

+64
-20
lines changed

mypy/checkexpr.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
2323
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
2424
TypeAliasExpr, BackquoteExpr, EnumCallExpr,
25-
ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, TVAR, LITERAL_TYPE, REVEAL_TYPE
25+
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, TVAR, LITERAL_TYPE, REVEAL_TYPE
2626
)
2727
from mypy.literals import literal
2828
from mypy import nodes
@@ -1366,16 +1366,23 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
13661366
callables, variables = merge_typevars_in_callables_by_name(callables)
13671367

13681368
new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]]
1369+
new_kinds = list(callables[0].arg_kinds)
13691370
new_returns = [] # type: List[Type]
13701371

1371-
expected_names = callables[0].arg_names
1372-
expected_kinds = callables[0].arg_kinds
1373-
13741372
for target in callables:
1375-
if target.arg_names != expected_names or target.arg_kinds != expected_kinds:
1376-
# We conservatively end if the overloads do not have the exact same signature.
1377-
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
1373+
# We conservatively end if the overloads do not have the exact same signature.
1374+
# The only exception is if one arg is optional and the other is positional: in that
1375+
# case, we continue unioning (and expect a positional arg).
1376+
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
1377+
if len(new_kinds) != len(target.arg_kinds):
13781378
return None
1379+
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)):
1380+
if new_kind == target_kind:
1381+
continue
1382+
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT):
1383+
new_kinds[i] = ARG_POS
1384+
else:
1385+
return None
13791386

13801387
for i, arg in enumerate(target.arg_types):
13811388
new_args[i].append(arg)
@@ -1390,7 +1397,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
13901397

13911398
# TODO: Modify this check to be less conservative.
13921399
#
1393-
# Currently, we permit only one union union in the arguments because if we allow
1400+
# Currently, we permit only one union in the arguments because if we allow
13941401
# multiple, we can't always guarantee the synthesized callable will be correct.
13951402
#
13961403
# For example, suppose we had the following two overloads:
@@ -1412,6 +1419,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
14121419

14131420
return callables[0].copy_modified(
14141421
arg_types=final_args,
1422+
arg_kinds=new_kinds,
14151423
ret_type=UnionType.make_simplified_union(new_returns),
14161424
variables=variables,
14171425
implicit=True)

test-data/unit/check-overloading.test

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,7 +2382,7 @@ reveal_type(foo(compat)) # E: Revealed type is 'Union[builtins.int, builtins.st
23822382
not_compat: Union[WrapperCo[A], WrapperContra[C]]
23832383
foo(not_compat) # E: Argument 1 to "foo" has incompatible type "Union[WrapperCo[A], WrapperContra[C]]"; expected "Union[WrapperCo[B], WrapperContra[B]]"
23842384

2385-
[case testOverloadInferUnionSkipIfParameterNamesAreDifferent]
2385+
[case testOverloadInferUnionIfParameterNamesAreDifferent]
23862386
from typing import overload, Union
23872387

23882388
class A: ...
@@ -2398,7 +2398,7 @@ def f(x): ...
23982398
x: Union[A, B]
23992399
reveal_type(f(A())) # E: Revealed type is '__main__.B'
24002400
reveal_type(f(B())) # E: Revealed type is '__main__.C'
2401-
f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A"
2401+
reveal_type(f(x)) # E: Revealed type is 'Union[__main__.B, __main__.C]'
24022402

24032403
[case testOverloadInferUnionReturnFunctionsWithKwargs]
24042404
from typing import overload, Union, Optional
@@ -2416,20 +2416,56 @@ def f(x: A, y: Optional[B] = None) -> C: ...
24162416
def f(x: A, z: Optional[C] = None) -> B: ...
24172417
def f(x, y=None, z=None): ...
24182418

2419-
reveal_type(f(A(), B()))
2420-
reveal_type(f(A(), C()))
2419+
reveal_type(f(A(), B())) # E: Revealed type is '__main__.C'
2420+
reveal_type(f(A(), C())) # E: Revealed type is '__main__.B'
24212421

24222422
arg: Union[B, C]
2423-
reveal_type(f(A(), arg))
2424-
reveal_type(f(A()))
2423+
reveal_type(f(A(), arg)) # E: Revealed type is 'Union[__main__.C, __main__.B]'
2424+
reveal_type(f(A())) # E: Revealed type is '__main__.D'
24252425

24262426
[builtins fixtures/tuple.pyi]
2427-
[out]
2428-
main:16: error: Revealed type is '__main__.C'
2429-
main:17: error: Revealed type is '__main__.B'
2430-
main:20: error: Revealed type is '__main__.C'
2431-
main:20: error: Argument 2 to "f" has incompatible type "Union[B, C]"; expected "Optional[B]"
2432-
main:21: error: Revealed type is '__main__.D'
2427+
2428+
[case testOverloadInferUnionWithDifferingLengths]
2429+
from typing import overload, Union
2430+
2431+
class Parent: ...
2432+
class Child(Parent): ...
2433+
2434+
class A: ...
2435+
class B: ...
2436+
2437+
@overload
2438+
def f(x: A) -> Child: ...
2439+
@overload
2440+
def f(x: B, y: B = B()) -> Parent: ...
2441+
def f(*args): ...
2442+
2443+
# TODO: It would be nice if we could successfully do union math
2444+
# in this case. See comments in checkexpr.union_overload_matches.
2445+
x: Union[A, B]
2446+
f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A"
2447+
f(x, B()) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "B"
2448+
2449+
[case testOverloadInferUnionWithMixOfPositionalAndOptionalArgs]
2450+
# flags: --strict-optional
2451+
from typing import overload, Union, Optional
2452+
2453+
class A: ...
2454+
class B: ...
2455+
2456+
@overload
2457+
def f(x: A) -> int: ...
2458+
@overload
2459+
def f(x: Optional[B] = None) -> str: ...
2460+
def f(*args): ...
2461+
2462+
x: Union[A, B]
2463+
y: Optional[A]
2464+
z: Union[A, Optional[B]]
2465+
reveal_type(f(x)) # E: Revealed type is 'Union[builtins.int, builtins.str]'
2466+
reveal_type(f(y)) # E: Revealed type is 'Union[builtins.int, builtins.str]'
2467+
reveal_type(f(z)) # E: Revealed type is 'Union[builtins.int, builtins.str]'
2468+
reveal_type(f()) # E: Revealed type is 'builtins.str'
24332469

24342470
[case testOverloadingInferUnionReturnWithTypevarWithValueRestriction]
24352471
from typing import overload, Union, TypeVar, Generic

0 commit comments

Comments
 (0)