diff --git a/mypy/join.py b/mypy/join.py index 166434f58f8d..b12a6788246c 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -402,10 +402,28 @@ def visit_instance(self, t: Instance) -> ProperType: return self.default(self.s) def visit_callable_type(self, t: CallableType) -> ProperType: - if isinstance(self.s, CallableType) and is_similar_callables(t, self.s): - if is_equivalent(t, self.s): - return combine_similar_callables(t, self.s) - result = join_similar_callables(t, self.s) + if isinstance(self.s, CallableType): + if is_similar_callables(t, self.s): + if is_equivalent(t, self.s): + return combine_similar_callables(t, self.s) + result = join_similar_callables(t, self.s) + if any( + isinstance(tp, (NoneType, UninhabitedType)) + for tp in get_proper_types(result.arg_types) + ): + # We don't want to return unusable Callable, attempt fallback instead. + return join_types(t.fallback, self.s) + else: + s, t = self.s, t + if t.is_var_arg: + s, t = t, s + if is_subtype(self.s, t): + result = t.copy_modified() + elif is_subtype(t, self.s): + result = self.s.copy_modified() + else: + return join_types(t.fallback, self.s) + # We set the from_type_type flag to suppress error when a collection of # concrete class objects gets inferred as their common abstract superclass. if not ( @@ -413,12 +431,6 @@ def visit_callable_type(self, t: CallableType) -> ProperType: or (self.s.is_type_obj() and self.s.type_object().is_abstract) ): result.from_type_type = True - if any( - isinstance(tp, (NoneType, UninhabitedType)) - for tp in get_proper_types(result.arg_types) - ): - # We don't want to return unusable Callable, attempt fallback instead. - return join_types(t.fallback, self.s) return result elif isinstance(self.s, Overloaded): # Switch the order of arguments to that we'll get to visit_overloaded. diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 18425efb9cb0..11974113e941 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3472,3 +3472,39 @@ class Qux(Bar): def baz(self, x) -> None: pass [builtins fixtures/tuple.pyi] + +[case testCallableJoinWithDefaults] +from typing import Callable, TypeVar + +T = TypeVar("T") + +def join(t1: T, t2: T) -> T: ... + +def f1() -> None: ... +def f2(i: int = 0) -> None: ... +def f3(i: str = "") -> None: ... + +reveal_type(join(f1, f2)) # N: Revealed type is "def ()" +reveal_type(join(f1, f3)) # N: Revealed type is "def ()" +reveal_type(join(f2, f3)) # N: Revealed type is "builtins.function" # TODO: this could be better +[builtins fixtures/tuple.pyi] + +[case testCallableJoinWithDefaultsMultiple] +from typing import TypeVar +T = TypeVar("T") +def join(t1: T, t2: T, t3: T) -> T: ... + +def f_1(common, a=None): ... +def f_any(*_, **__): ... +def f_3(common, b=None, x=None): ... + +fdict = { + "f_1": f_1, + "f_any": f_any, + "f_3": f_3, +} +reveal_type(fdict) # N: Revealed type is "builtins.dict[builtins.str, def (common: Any, a: Any =) -> Any]" + +reveal_type(join(f_1, f_any, f_3)) # N: Revealed type is "def (common: Any, a: Any =) -> Any" + +[builtins fixtures/tuple.pyi]