diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 09aca5fedf09..09232ba037d5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -22,7 +22,7 @@ DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr, YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, TypeAliasExpr, BackquoteExpr, EnumCallExpr, - ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, TVAR, LITERAL_TYPE, REVEAL_TYPE + ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, TVAR, LITERAL_TYPE, REVEAL_TYPE ) from mypy.literals import literal from mypy import nodes @@ -1366,16 +1366,23 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call callables, variables = merge_typevars_in_callables_by_name(callables) new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]] + new_kinds = list(callables[0].arg_kinds) new_returns = [] # type: List[Type] - expected_names = callables[0].arg_names - expected_kinds = callables[0].arg_kinds - for target in callables: - if target.arg_names != expected_names or target.arg_kinds != expected_kinds: - # We conservatively end if the overloads do not have the exact same signature. - # TODO: Enhance the union overload logic to handle a wider variety of signatures. + # We conservatively end if the overloads do not have the exact same signature. + # The only exception is if one arg is optional and the other is positional: in that + # case, we continue unioning (and expect a positional arg). + # TODO: Enhance the union overload logic to handle a wider variety of signatures. + if len(new_kinds) != len(target.arg_kinds): return None + for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)): + if new_kind == target_kind: + continue + elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT): + new_kinds[i] = ARG_POS + else: + return None for i, arg in enumerate(target.arg_types): new_args[i].append(arg) @@ -1390,7 +1397,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call # TODO: Modify this check to be less conservative. # - # Currently, we permit only one union union in the arguments because if we allow + # Currently, we permit only one union in the arguments because if we allow # multiple, we can't always guarantee the synthesized callable will be correct. # # For example, suppose we had the following two overloads: @@ -1412,6 +1419,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call return callables[0].copy_modified( arg_types=final_args, + arg_kinds=new_kinds, ret_type=UnionType.make_simplified_union(new_returns), variables=variables, implicit=True) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index cf722b2f63d0..b1df8f7fbec8 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -2157,7 +2157,7 @@ reveal_type(foo(compat)) # E: Revealed type is 'Union[builtins.int, builtins.st not_compat: Union[WrapperCo[A], WrapperContra[C]] foo(not_compat) # E: Argument 1 to "foo" has incompatible type "Union[WrapperCo[A], WrapperContra[C]]"; expected "Union[WrapperCo[B], WrapperContra[B]]" -[case testOverloadInferUnionSkipIfParameterNamesAreDifferent] +[case testOverloadInferUnionIfParameterNamesAreDifferent] from typing import overload, Union class A: ... @@ -2173,7 +2173,7 @@ def f(x): ... x: Union[A, B] reveal_type(f(A())) # E: Revealed type is '__main__.B' reveal_type(f(B())) # E: Revealed type is '__main__.C' -f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A" +reveal_type(f(x)) # E: Revealed type is 'Union[__main__.B, __main__.C]' [case testOverloadInferUnionReturnFunctionsWithKwargs] from typing import overload, Union, Optional @@ -2191,20 +2191,56 @@ def f(x: A, y: Optional[B] = None) -> C: ... def f(x: A, z: Optional[C] = None) -> B: ... def f(x, y=None, z=None): ... -reveal_type(f(A(), B())) -reveal_type(f(A(), C())) +reveal_type(f(A(), B())) # E: Revealed type is '__main__.C' +reveal_type(f(A(), C())) # E: Revealed type is '__main__.B' arg: Union[B, C] -reveal_type(f(A(), arg)) -reveal_type(f(A())) +reveal_type(f(A(), arg)) # E: Revealed type is 'Union[__main__.C, __main__.B]' +reveal_type(f(A())) # E: Revealed type is '__main__.D' [builtins fixtures/tuple.pyi] -[out] -main:16: error: Revealed type is '__main__.C' -main:17: error: Revealed type is '__main__.B' -main:20: error: Revealed type is '__main__.C' -main:20: error: Argument 2 to "f" has incompatible type "Union[B, C]"; expected "Optional[B]" -main:21: error: Revealed type is '__main__.D' + +[case testOverloadInferUnionWithDifferingLengths] +from typing import overload, Union + +class Parent: ... +class Child(Parent): ... + +class A: ... +class B: ... + +@overload +def f(x: A) -> Child: ... +@overload +def f(x: B, y: B = B()) -> Parent: ... +def f(*args): ... + +# TODO: It would be nice if we could successfully do union math +# in this case. See comments in checkexpr.union_overload_matches. +x: Union[A, B] +f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A" +f(x, B()) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "B" + +[case testOverloadInferUnionWithMixOfPositionalAndOptionalArgs] +# flags: --strict-optional +from typing import overload, Union, Optional + +class A: ... +class B: ... + +@overload +def f(x: A) -> int: ... +@overload +def f(x: Optional[B] = None) -> str: ... +def f(*args): ... + +x: Union[A, B] +y: Optional[A] +z: Union[A, Optional[B]] +reveal_type(f(x)) # E: Revealed type is 'Union[builtins.int, builtins.str]' +reveal_type(f(y)) # E: Revealed type is 'Union[builtins.int, builtins.str]' +reveal_type(f(z)) # E: Revealed type is 'Union[builtins.int, builtins.str]' +reveal_type(f()) # E: Revealed type is 'builtins.str' [case testOverloadingInferUnionReturnWithTypevarWithValueRestriction] from typing import overload, Union, TypeVar, Generic