Skip to content

Commit 00359ad

Browse files
ilevkivskyiJukkaL
authored andcommitted
Improve join and meet of callables and overloads (#2833)
Fixes #1983 Here I implement: * join(Callable[[A1], R1]), Callable[[A2], R2]) == Callable[[meet(A1, A2)], join(R1, R2)] * meet(Callable[[A1], R1]), Callable[[A2], R2]) == Callable[[join(A1, A2)], meet(R1, R2)] plus special cases for Any, overloads, and callable type objects. The meet and join are still not perfect, but I think this PR improves the situation.
1 parent 62d4bc8 commit 00359ad

File tree

5 files changed

+120
-16
lines changed

5 files changed

+120
-16
lines changed

mypy/checkexpr.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -2435,9 +2435,12 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int:
24352435
(isinstance(actual, Instance) and actual.type.fallback_to_any)):
24362436
# These could match anything at runtime.
24372437
return 2
2438-
if isinstance(formal, CallableType) and isinstance(actual, (CallableType, Overloaded)):
2439-
# TODO: do more sophisticated callable matching
2440-
return 2
2438+
if isinstance(formal, CallableType):
2439+
if isinstance(actual, (CallableType, Overloaded)):
2440+
# TODO: do more sophisticated callable matching
2441+
return 2
2442+
if isinstance(actual, TypeType):
2443+
return 2 if is_subtype(actual, formal) else 0
24412444
if isinstance(actual, NoneTyp):
24422445
if not experiments.STRICT_OPTIONAL:
24432446
# NoneTyp matches anything if we're not doing strict Optional checking

mypy/join.py

+33-7
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,14 @@ def visit_instance(self, t: Instance) -> Type:
152152
return self.default(self.s)
153153

154154
def visit_callable_type(self, t: CallableType) -> Type:
155-
# TODO: Consider subtyping instead of just similarity.
156155
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
157-
return combine_similar_callables(t, self.s)
156+
if is_equivalent(t, self.s):
157+
return combine_similar_callables(t, self.s)
158+
result = join_similar_callables(t, self.s)
159+
if any(isinstance(tp, (NoneTyp, UninhabitedType)) for tp in result.arg_types):
160+
# We don't want to return unusable Callable, attempt fallback instead.
161+
return join_types(t.fallback, self.s)
162+
return result
158163
elif isinstance(self.s, Overloaded):
159164
# Switch the order of arguments to that we'll get to visit_overloaded.
160165
return join_types(t, self.s)
@@ -189,15 +194,18 @@ def visit_overloaded(self, t: Overloaded) -> Type:
189194
# join(Ov([int, Any] -> Any, [str, Any] -> Any), [Any, int] -> Any) ==
190195
# Ov([Any, int] -> Any, [Any, int] -> Any)
191196
#
192-
# TODO: Use callable subtyping instead of just similarity.
197+
# TODO: Consider more cases of callable subtyping.
193198
result = [] # type: List[CallableType]
194199
s = self.s
195200
if isinstance(s, FunctionLike):
196201
# The interesting case where both types are function types.
197202
for t_item in t.items():
198203
for s_item in s.items():
199204
if is_similar_callables(t_item, s_item):
200-
result.append(combine_similar_callables(t_item, s_item))
205+
if is_equivalent(t_item, s_item):
206+
result.append(combine_similar_callables(t_item, s_item))
207+
elif is_subtype(t_item, s_item):
208+
result.append(s_item)
201209
if result:
202210
# TODO: Simplify redundancies from the result.
203211
if len(result) == 1:
@@ -323,12 +331,30 @@ def is_better(t: Type, s: Type) -> bool:
323331

324332

325333
def is_similar_callables(t: CallableType, s: CallableType) -> bool:
326-
"""Return True if t and s are equivalent and have identical numbers of
334+
"""Return True if t and s have identical numbers of
327335
arguments, default arguments and varargs.
328336
"""
329337

330-
return (len(t.arg_types) == len(s.arg_types) and t.min_args == s.min_args
331-
and t.is_var_arg == s.is_var_arg and is_equivalent(t, s))
338+
return (len(t.arg_types) == len(s.arg_types) and t.min_args == s.min_args and
339+
t.is_var_arg == s.is_var_arg)
340+
341+
342+
def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
343+
from mypy.meet import meet_types
344+
arg_types = [] # type: List[Type]
345+
for i in range(len(t.arg_types)):
346+
arg_types.append(meet_types(t.arg_types[i], s.arg_types[i]))
347+
# TODO in combine_similar_callables also applies here (names and kinds)
348+
# The fallback type can be either 'function' or 'type'. The result should have 'type' as
349+
# fallback only if both operands have it as 'type'.
350+
if t.fallback.type.fullname() != 'builtins.type':
351+
fallback = t.fallback
352+
else:
353+
fallback = s.fallback
354+
return t.copy_modified(arg_types=arg_types,
355+
ret_type=join_types(t.ret_type, s.ret_type),
356+
fallback=fallback,
357+
name=None)
332358

333359

334360
def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:

mypy/meet.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,13 @@ def visit_instance(self, t: Instance) -> Type:
220220

221221
def visit_callable_type(self, t: CallableType) -> Type:
222222
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
223-
return combine_similar_callables(t, self.s)
223+
if is_equivalent(t, self.s):
224+
return combine_similar_callables(t, self.s)
225+
result = meet_similar_callables(t, self.s)
226+
if isinstance(result.ret_type, UninhabitedType):
227+
# Return a plain None or <uninhabited> instead of a weird function.
228+
return self.default(self.s)
229+
return result
224230
else:
225231
return self.default(self.s)
226232

@@ -279,3 +285,21 @@ def default(self, typ: Type) -> Type:
279285
return UninhabitedType()
280286
else:
281287
return NoneTyp()
288+
289+
290+
def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType:
291+
from mypy.join import join_types
292+
arg_types = [] # type: List[Type]
293+
for i in range(len(t.arg_types)):
294+
arg_types.append(join_types(t.arg_types[i], s.arg_types[i]))
295+
# TODO in combine_similar_callables also applies here (names and kinds)
296+
# The fallback type can be either 'function' or 'type'. The result should have 'function' as
297+
# fallback only if both operands have it as 'function'.
298+
if t.fallback.type.fullname() != 'builtins.function':
299+
fallback = t.fallback
300+
else:
301+
fallback = s.fallback
302+
return t.copy_modified(arg_types=arg_types,
303+
ret_type=meet_types(t.ret_type, s.ret_type),
304+
fallback=fallback,
305+
name=None)

mypy/test/testtypes.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,16 @@ def test_function_types(self) -> None:
392392

393393
self.assert_join(self.callable(self.fx.a, self.fx.b),
394394
self.callable(self.fx.b, self.fx.b),
395-
self.fx.function)
395+
self.callable(self.fx.b, self.fx.b))
396396
self.assert_join(self.callable(self.fx.a, self.fx.b),
397397
self.callable(self.fx.a, self.fx.a),
398-
self.fx.function)
398+
self.callable(self.fx.a, self.fx.a))
399399
self.assert_join(self.callable(self.fx.a, self.fx.b),
400400
self.fx.function,
401401
self.fx.function)
402+
self.assert_join(self.callable(self.fx.a, self.fx.b),
403+
self.callable(self.fx.d, self.fx.b),
404+
self.fx.function)
402405

403406
def test_type_vars(self) -> None:
404407
self.assert_join(self.fx.t, self.fx.t, self.fx.t)
@@ -560,13 +563,14 @@ def test_generic_interfaces(self) -> None:
560563
def test_simple_type_objects(self) -> None:
561564
t1 = self.type_callable(self.fx.a, self.fx.a)
562565
t2 = self.type_callable(self.fx.b, self.fx.b)
566+
tr = self.type_callable(self.fx.b, self.fx.a)
563567

564568
self.assert_join(t1, t1, t1)
565569
j = join_types(t1, t1)
566570
assert isinstance(j, CallableType)
567571
assert_true(j.is_type_obj())
568572

569-
self.assert_join(t1, t2, self.fx.type_type)
573+
self.assert_join(t1, t2, tr)
570574
self.assert_join(t1, self.fx.type_type, self.fx.type_type)
571575
self.assert_join(self.fx.type_type, self.fx.type_type,
572576
self.fx.type_type)
@@ -658,10 +662,10 @@ def test_function_types(self) -> None:
658662

659663
self.assert_meet(self.callable(self.fx.a, self.fx.b),
660664
self.callable(self.fx.b, self.fx.b),
661-
NoneTyp())
665+
self.callable(self.fx.a, self.fx.b))
662666
self.assert_meet(self.callable(self.fx.a, self.fx.b),
663667
self.callable(self.fx.a, self.fx.a),
664-
NoneTyp())
668+
self.callable(self.fx.a, self.fx.b))
665669

666670
def test_type_vars(self) -> None:
667671
self.assert_meet(self.fx.t, self.fx.t, self.fx.t)

test-data/unit/check-inference.test

+47
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,53 @@ i(b, a, b)
696696
i(a, b, b) # E: Argument 1 to "i" has incompatible type List[int]; expected List[str]
697697
[builtins fixtures/list.pyi]
698698

699+
[case testCallableListJoinInference]
700+
from typing import Any, Callable
701+
702+
def fun() -> None:
703+
callbacks = [
704+
callback1,
705+
callback2,
706+
]
707+
708+
for c in callbacks:
709+
call(c, 1234) # this must not fail
710+
711+
def callback1(i: int) -> int:
712+
return i
713+
def callback2(i: int) -> str:
714+
return 'hello'
715+
def call(c: Callable[[int], Any], i: int) -> None:
716+
c(i)
717+
[builtins fixtures/list.pyi]
718+
[out]
719+
720+
[case testCallableMeetAndJoin]
721+
# flags: --python-version 3.6
722+
from typing import Callable, Any, TypeVar
723+
724+
class A: ...
725+
class B(A): ...
726+
727+
def f(c: Callable[[B], int]) -> None: ...
728+
729+
c: Callable[[A], int]
730+
d: Callable[[B], int]
731+
732+
lst = [c, d]
733+
reveal_type(lst) # E: Revealed type is 'builtins.list[def (__main__.B) -> builtins.int]'
734+
735+
T = TypeVar('T')
736+
def meet_test(x: Callable[[T], int], y: Callable[[T], int]) -> T: ...
737+
738+
CA = Callable[[A], A]
739+
CB = Callable[[B], B]
740+
741+
ca: Callable[[CA], int]
742+
cb: Callable[[CB], int]
743+
reveal_type(meet_test(ca, cb)) # E: Revealed type is 'def (__main__.A) -> __main__.B'
744+
[builtins fixtures/list.pyi]
745+
[out]
699746

700747
[case testUnionInferenceWithTypeVarValues]
701748
from typing import TypeVar, Union

0 commit comments

Comments
 (0)