Skip to content

Commit caff030

Browse files
authored
Add type inference for class object vs generic protocol (#13511)
I forgot to add this to yesterdays PR #13501
1 parent d68b1c6 commit caff030

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

mypy/constraints.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,33 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
541541
template.type.inferring.pop()
542542
return res
543543
if isinstance(actual, CallableType) and actual.fallback is not None:
544+
if actual.is_type_obj() and template.type.is_protocol:
545+
ret_type = get_proper_type(actual.ret_type)
546+
if isinstance(ret_type, TupleType):
547+
ret_type = mypy.typeops.tuple_fallback(ret_type)
548+
if isinstance(ret_type, Instance):
549+
if self.direction == SUBTYPE_OF:
550+
subtype = template
551+
else:
552+
subtype = ret_type
553+
res.extend(
554+
self.infer_constraints_from_protocol_members(
555+
ret_type, template, subtype, template, class_obj=True
556+
)
557+
)
544558
actual = actual.fallback
559+
if isinstance(actual, TypeType) and template.type.is_protocol:
560+
if isinstance(actual.item, Instance):
561+
if self.direction == SUBTYPE_OF:
562+
subtype = template
563+
else:
564+
subtype = actual.item
565+
res.extend(
566+
self.infer_constraints_from_protocol_members(
567+
actual.item, template, subtype, template, class_obj=True
568+
)
569+
)
570+
545571
if isinstance(actual, Overloaded) and actual.fallback is not None:
546572
actual = actual.fallback
547573
if isinstance(actual, TypedDictType):
@@ -715,6 +741,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
715741
)
716742
instance.type.inferring.pop()
717743
return res
744+
if res:
745+
return res
746+
718747
if isinstance(actual, AnyType):
719748
return self.infer_against_any(template.args, actual)
720749
if (
@@ -740,7 +769,12 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
740769
return []
741770

742771
def infer_constraints_from_protocol_members(
743-
self, instance: Instance, template: Instance, subtype: Type, protocol: Instance
772+
self,
773+
instance: Instance,
774+
template: Instance,
775+
subtype: Type,
776+
protocol: Instance,
777+
class_obj: bool = False,
744778
) -> list[Constraint]:
745779
"""Infer constraints for situations where either 'template' or 'instance' is a protocol.
746780
@@ -750,7 +784,7 @@ def infer_constraints_from_protocol_members(
750784
"""
751785
res = []
752786
for member in protocol.type.protocol_members:
753-
inst = mypy.subtypes.find_member(member, instance, subtype)
787+
inst = mypy.subtypes.find_member(member, instance, subtype, class_obj=class_obj)
754788
temp = mypy.subtypes.find_member(member, template, subtype)
755789
if inst is None or temp is None:
756790
return [] # See #11020

test-data/unit/check-protocols.test

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3517,3 +3517,32 @@ test(c) # E: Argument 1 to "test" has incompatible type "Type[C]"; expected "P"
35173517
# N: def [T] foo(arg: T) -> T \
35183518
# N: Got: \
35193519
# N: def [T] foo(self: T) -> Union[T, int]
3520+
3521+
[case testProtocolClassObjectInference]
3522+
from typing import Any, Protocol, TypeVar
3523+
3524+
T = TypeVar("T", contravariant=True)
3525+
class P(Protocol[T]):
3526+
def foo(self, obj: T) -> int: ...
3527+
3528+
class B:
3529+
def foo(self) -> int: ...
3530+
3531+
S = TypeVar("S")
3532+
def test(arg: P[S]) -> S: ...
3533+
reveal_type(test(B)) # N: Revealed type is "__main__.B"
3534+
3535+
[case testProtocolTypeTypeInference]
3536+
from typing import Any, Protocol, TypeVar, Type
3537+
3538+
T = TypeVar("T", contravariant=True)
3539+
class P(Protocol[T]):
3540+
def foo(self, obj: T) -> int: ...
3541+
3542+
class B:
3543+
def foo(self) -> int: ...
3544+
3545+
S = TypeVar("S")
3546+
def test(arg: P[S]) -> S: ...
3547+
b: Type[B]
3548+
reveal_type(test(b)) # N: Revealed type is "__main__.B"

0 commit comments

Comments
 (0)