Skip to content

Relax union math logic to allow signatures with different arg kinds #5222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
TypeAliasExpr, BackquoteExpr, EnumCallExpr,
ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, TVAR, LITERAL_TYPE, REVEAL_TYPE
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, TVAR, LITERAL_TYPE, REVEAL_TYPE
)
from mypy.literals import literal
from mypy import nodes
Expand Down Expand Up @@ -1366,16 +1366,23 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
callables, variables = merge_typevars_in_callables_by_name(callables)

new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]]
new_kinds = list(callables[0].arg_kinds)
new_returns = [] # type: List[Type]

expected_names = callables[0].arg_names
expected_kinds = callables[0].arg_kinds

for target in callables:
if target.arg_names != expected_names or target.arg_kinds != expected_kinds:
# We conservatively end if the overloads do not have the exact same signature.
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
# We conservatively end if the overloads do not have the exact same signature.
# The only exception is if one arg is optional and the other is positional: in that
# case, we continue unioning (and expect a positional arg).
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
if len(new_kinds) != len(target.arg_kinds):
return None
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)):
if new_kind == target_kind:
continue
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT):
new_kinds[i] = ARG_POS
else:
return None

for i, arg in enumerate(target.arg_types):
new_args[i].append(arg)
Expand All @@ -1390,7 +1397,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call

# TODO: Modify this check to be less conservative.
#
# Currently, we permit only one union union in the arguments because if we allow
# Currently, we permit only one union in the arguments because if we allow
# multiple, we can't always guarantee the synthesized callable will be correct.
#
# For example, suppose we had the following two overloads:
Expand All @@ -1412,6 +1419,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call

return callables[0].copy_modified(
arg_types=final_args,
arg_kinds=new_kinds,
ret_type=UnionType.make_simplified_union(new_returns),
variables=variables,
implicit=True)
Expand Down
60 changes: 48 additions & 12 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -2157,7 +2157,7 @@ reveal_type(foo(compat)) # E: Revealed type is 'Union[builtins.int, builtins.st
not_compat: Union[WrapperCo[A], WrapperContra[C]]
foo(not_compat) # E: Argument 1 to "foo" has incompatible type "Union[WrapperCo[A], WrapperContra[C]]"; expected "Union[WrapperCo[B], WrapperContra[B]]"

[case testOverloadInferUnionSkipIfParameterNamesAreDifferent]
[case testOverloadInferUnionIfParameterNamesAreDifferent]
from typing import overload, Union

class A: ...
Expand All @@ -2173,7 +2173,7 @@ def f(x): ...
x: Union[A, B]
reveal_type(f(A())) # E: Revealed type is '__main__.B'
reveal_type(f(B())) # E: Revealed type is '__main__.C'
f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A"
reveal_type(f(x)) # E: Revealed type is 'Union[__main__.B, __main__.C]'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test case name should be updated (remove "Skip").


[case testOverloadInferUnionReturnFunctionsWithKwargs]
from typing import overload, Union, Optional
Expand All @@ -2191,20 +2191,56 @@ def f(x: A, y: Optional[B] = None) -> C: ...
def f(x: A, z: Optional[C] = None) -> B: ...
def f(x, y=None, z=None): ...

reveal_type(f(A(), B()))
reveal_type(f(A(), C()))
reveal_type(f(A(), B())) # E: Revealed type is '__main__.C'
reveal_type(f(A(), C())) # E: Revealed type is '__main__.B'

arg: Union[B, C]
reveal_type(f(A(), arg))
reveal_type(f(A()))
reveal_type(f(A(), arg)) # E: Revealed type is 'Union[__main__.C, __main__.B]'
reveal_type(f(A())) # E: Revealed type is '__main__.D'

[builtins fixtures/tuple.pyi]
[out]
main:16: error: Revealed type is '__main__.C'
main:17: error: Revealed type is '__main__.B'
main:20: error: Revealed type is '__main__.C'
main:20: error: Argument 2 to "f" has incompatible type "Union[B, C]"; expected "Optional[B]"
main:21: error: Revealed type is '__main__.D'

[case testOverloadInferUnionWithDifferingLengths]
from typing import overload, Union

class Parent: ...
class Child(Parent): ...

class A: ...
class B: ...

@overload
def f(x: A) -> Child: ...
@overload
def f(x: B, y: B = B()) -> Parent: ...
def f(*args): ...

# TODO: It would be nice if we could successfully do union math
# in this case. See comments in checkexpr.union_overload_matches.
x: Union[A, B]
f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A"
f(x, B()) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "B"

[case testOverloadInferUnionWithMixOfPositionalAndOptionalArgs]
# flags: --strict-optional
from typing import overload, Union, Optional

class A: ...
class B: ...

@overload
def f(x: A) -> int: ...
@overload
def f(x: Optional[B] = None) -> str: ...
def f(*args): ...

x: Union[A, B]
y: Optional[A]
z: Union[A, Optional[B]]
reveal_type(f(x)) # E: Revealed type is 'Union[builtins.int, builtins.str]'
reveal_type(f(y)) # E: Revealed type is 'Union[builtins.int, builtins.str]'
reveal_type(f(z)) # E: Revealed type is 'Union[builtins.int, builtins.str]'
reveal_type(f()) # E: Revealed type is 'builtins.str'

[case testOverloadingInferUnionReturnWithTypevarWithValueRestriction]
from typing import overload, Union, TypeVar, Generic
Expand Down