Skip to content

Commit 249f3f8

Browse files
authored
Fix inference for overloaded __call__ with generic self (#16053)
Fixes #8283 Co-authored-by: ilevkivskyi
1 parent ba978f4 commit 249f3f8

File tree

5 files changed

+76
-30
lines changed

5 files changed

+76
-30
lines changed

mypy/checkexpr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,7 @@ def check_call(
14751475
callable_node: Expression | None = None,
14761476
callable_name: str | None = None,
14771477
object_type: Type | None = None,
1478+
original_type: Type | None = None,
14781479
) -> tuple[Type, Type]:
14791480
"""Type check a call.
14801481
@@ -1537,7 +1538,7 @@ def check_call(
15371538
is_super=False,
15381539
is_operator=True,
15391540
msg=self.msg,
1540-
original_type=callee,
1541+
original_type=original_type or callee,
15411542
chk=self.chk,
15421543
in_literal_context=self.is_literal_context(),
15431544
)
@@ -1578,6 +1579,7 @@ def check_call(
15781579
callable_node,
15791580
callable_name,
15801581
object_type,
1582+
original_type=callee,
15811583
)
15821584
else:
15831585
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)

mypy/checkmember.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -331,13 +331,12 @@ def analyze_instance_member_access(
331331
signature = method.type
332332
signature = freshen_all_functions_type_vars(signature)
333333
if not method.is_static:
334-
if name != "__call__":
335-
# TODO: use proper treatment of special methods on unions instead
336-
# of this hack here and below (i.e. mx.self_type).
337-
dispatched_type = meet.meet_types(mx.original_type, typ)
338-
signature = check_self_arg(
339-
signature, dispatched_type, method.is_class, mx.context, name, mx.msg
340-
)
334+
# TODO: use proper treatment of special methods on unions instead
335+
# of this hack here and below (i.e. mx.self_type).
336+
dispatched_type = meet.meet_types(mx.original_type, typ)
337+
signature = check_self_arg(
338+
signature, dispatched_type, method.is_class, mx.context, name, mx.msg
339+
)
341340
signature = bind_self(signature, mx.self_type, is_classmethod=method.is_class)
342341
# TODO: should we skip these steps for static methods as well?
343342
# Since generic static methods should not be allowed.

mypy/subtypes.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -454,19 +454,22 @@ def visit_instance(self, left: Instance) -> bool:
454454
if isinstance(unpacked, Instance):
455455
return self._is_subtype(left, unpacked)
456456
if left.type.has_base(right.partial_fallback.type.fullname):
457-
# Special case to consider Foo[*tuple[Any, ...]] (i.e. bare Foo) a
458-
# subtype of Foo[<whatever>], when Foo is user defined variadic tuple type.
459-
mapped = map_instance_to_supertype(left, right.partial_fallback.type)
460-
if len(mapped.args) == 1 and isinstance(mapped.args[0], UnpackType):
461-
unpacked = get_proper_type(mapped.args[0].type)
462-
if isinstance(unpacked, Instance):
463-
assert unpacked.type.fullname == "builtins.tuple"
464-
if isinstance(get_proper_type(unpacked.args[0]), AnyType):
465-
return not self.proper_subtype
466-
if mapped.type.fullname == "builtins.tuple" and isinstance(
467-
get_proper_type(mapped.args[0]), AnyType
468-
):
469-
return not self.proper_subtype
457+
if not self.proper_subtype:
458+
# Special case to consider Foo[*tuple[Any, ...]] (i.e. bare Foo) a
459+
# subtype of Foo[<whatever>], when Foo is user defined variadic tuple type.
460+
mapped = map_instance_to_supertype(left, right.partial_fallback.type)
461+
for arg in map(get_proper_type, mapped.args):
462+
if isinstance(arg, UnpackType):
463+
unpacked = get_proper_type(arg.type)
464+
if not isinstance(unpacked, Instance):
465+
break
466+
assert unpacked.type.fullname == "builtins.tuple"
467+
if not isinstance(get_proper_type(unpacked.args[0]), AnyType):
468+
break
469+
elif not isinstance(arg, AnyType):
470+
break
471+
else:
472+
return True
470473
return False
471474
if isinstance(right, TypeVarTupleType):
472475
# tuple[Any, ...] is like Any in the world of tuples (see special case above).
@@ -534,15 +537,19 @@ def visit_instance(self, left: Instance) -> bool:
534537
right_args = (
535538
right_prefix + (TupleType(list(right_middle), fallback),) + right_suffix
536539
)
537-
if len(t.args) == 1 and isinstance(t.args[0], UnpackType):
538-
unpacked = get_proper_type(t.args[0].type)
539-
if isinstance(unpacked, Instance):
540-
assert unpacked.type.fullname == "builtins.tuple"
541-
if (
542-
isinstance(get_proper_type(unpacked.args[0]), AnyType)
543-
and not self.proper_subtype
544-
):
545-
return True
540+
if not self.proper_subtype:
541+
for arg in map(get_proper_type, t.args):
542+
if isinstance(arg, UnpackType):
543+
unpacked = get_proper_type(arg.type)
544+
if not isinstance(unpacked, Instance):
545+
break
546+
assert unpacked.type.fullname == "builtins.tuple"
547+
if not isinstance(get_proper_type(unpacked.args[0]), AnyType):
548+
break
549+
elif not isinstance(arg, AnyType):
550+
break
551+
else:
552+
return True
546553
type_params = zip(left_args, right_args, right.type.defn.type_vars)
547554
else:
548555
type_params = zip(t.args, right.args, right.type.defn.type_vars)

test-data/unit/check-overloading.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6650,3 +6650,27 @@ def d(x: int) -> int: ...
66506650
def d(f: int, *, x: int) -> str: ...
66516651
def d(*args, **kwargs): ...
66526652
[builtins fixtures/tuple.pyi]
6653+
6654+
[case testOverloadCallableGenericSelf]
6655+
from typing import Any, TypeVar, Generic, overload, reveal_type
6656+
6657+
T = TypeVar("T")
6658+
6659+
class MyCallable(Generic[T]):
6660+
def __init__(self, t: T):
6661+
self.t = t
6662+
6663+
@overload
6664+
def __call__(self: "MyCallable[int]") -> str: ...
6665+
@overload
6666+
def __call__(self: "MyCallable[str]") -> int: ...
6667+
def __call__(self): ...
6668+
6669+
c = MyCallable(5)
6670+
reveal_type(c) # N: Revealed type is "__main__.MyCallable[builtins.int]"
6671+
reveal_type(c()) # N: Revealed type is "builtins.str"
6672+
6673+
c2 = MyCallable("test")
6674+
reveal_type(c2) # N: Revealed type is "__main__.MyCallable[builtins.str]"
6675+
reveal_type(c2()) # should be int # N: Revealed type is "builtins.int"
6676+
[builtins fixtures/tuple.pyi]

test-data/unit/check-tuples.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,21 @@ def foo(o: CallableTuple) -> int:
14341434
class CallableTuple(Tuple[str, int]):
14351435
def __call__(self, n: int, m: int) -> int:
14361436
return n
1437+
[builtins fixtures/tuple.pyi]
1438+
1439+
[case testTypeTupleGenericCall]
1440+
from typing import Generic, Tuple, TypeVar
1441+
1442+
T = TypeVar('T')
14371443

1444+
def foo(o: CallableTuple[int]) -> int:
1445+
reveal_type(o) # N: Revealed type is "Tuple[builtins.str, builtins.int, fallback=__main__.CallableTuple[builtins.int]]"
1446+
reveal_type(o.count(3)) # N: Revealed type is "builtins.int"
1447+
return o(1, 2)
1448+
1449+
class CallableTuple(Tuple[str, T]):
1450+
def __call__(self, n: int, m: int) -> int:
1451+
return n
14381452
[builtins fixtures/tuple.pyi]
14391453

14401454
[case testTupleCompatibleWithSequence]

0 commit comments

Comments
 (0)