diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7abcf8feadc5..ae2a4907e0c2 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1129,6 +1129,7 @@ def check_overload_call(self, erased_targets = None # type: Optional[List[CallableType]] unioned_result = None # type: Optional[Tuple[Type, Type]] unioned_errors = None # type: Optional[MessageBuilder] + union_success = False if any(isinstance(arg, UnionType) and len(arg.relevant_items()) > 1 # "real" union for arg in arg_types): erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, @@ -1142,18 +1143,26 @@ def check_overload_call(self, arg_messages=unioned_errors, callable_name=callable_name, object_type=object_type) - if not unioned_errors.is_errors(): - # Success! Stop early. - return unioned_result + # Record if we succeeded. Next we need to see if maybe normal procedure + # gives a narrower type. + union_success = unioned_result is not None and not unioned_errors.is_errors() - # Step 3: If the union math fails, or if there was no union in the argument types, - # we fall back to checking each branch one-by-one. + # Step 3: We try checking each branch one-by-one. inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types, arg_kinds, arg_names, callable_name, object_type, context, arg_messages) if inferred_result is not None: - # Success! Stop early. - return inferred_result + # Success! Stop early by returning the best among normal and unioned. + if not union_success: + return inferred_result + else: + assert unioned_result is not None + if is_subtype(inferred_result[0], unioned_result[0]): + return inferred_result + return unioned_result + elif union_success: + assert unioned_result is not None + return unioned_result # Step 4: Failure. At this point, we know there is no match. We fall back to trying # to find a somewhat plausible overload target using the erased types diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 54b01412d4ba..c74ec5e32958 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -3611,6 +3611,20 @@ class Wrapper: [builtins fixtures/staticmethod.pyi] +[case testUnionMathOverloadingReturnsBestType] +from typing import Union, overload + +@overload +def f(x: Union[int, str]) -> int: ... +@overload +def f(x: object) -> object: ... +def f(x): + pass + +x: Union[int, str] +reveal_type(f(x)) # E: Revealed type is 'builtins.int' +[out] + [case testOverloadAndSelfTypes] from typing import overload, Union, TypeVar, Type