Skip to content

Commit 9431d47

Browse files
authored
Allow classes as protocol implementations (#13501)
Fixes #4536 This use case is specified by PEP 544 but was not implemented. Only instances (not class objects) were allowed as protocol implementations. The PR is quite straightforward, essentially I just pass a `class_obj` flag everywhere to know whether we need or not to bind self in a method.
1 parent 09b0fa4 commit 9431d47

File tree

3 files changed

+526
-42
lines changed

3 files changed

+526
-42
lines changed

mypy/messages.py

Lines changed: 91 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,9 @@ def incompatible_argument_note(
743743
context: Context,
744744
code: ErrorCode | None,
745745
) -> None:
746-
if isinstance(original_caller_type, (Instance, TupleType, TypedDictType)):
746+
if isinstance(
747+
original_caller_type, (Instance, TupleType, TypedDictType, TypeType, CallableType)
748+
):
747749
if isinstance(callee_type, Instance) and callee_type.type.is_protocol:
748750
self.report_protocol_problems(
749751
original_caller_type, callee_type, context, code=code
@@ -1791,7 +1793,7 @@ def impossible_intersection(
17911793

17921794
def report_protocol_problems(
17931795
self,
1794-
subtype: Instance | TupleType | TypedDictType,
1796+
subtype: Instance | TupleType | TypedDictType | TypeType | CallableType,
17951797
supertype: Instance,
17961798
context: Context,
17971799
*,
@@ -1811,15 +1813,15 @@ def report_protocol_problems(
18111813
exclusions: dict[type, list[str]] = {
18121814
TypedDictType: ["typing.Mapping"],
18131815
TupleType: ["typing.Iterable", "typing.Sequence"],
1814-
Instance: [],
18151816
}
1816-
if supertype.type.fullname in exclusions[type(subtype)]:
1817+
if supertype.type.fullname in exclusions.get(type(subtype), []):
18171818
return
18181819
if any(isinstance(tp, UninhabitedType) for tp in get_proper_types(supertype.args)):
18191820
# We don't want to add notes for failed inference (e.g. Iterable[<nothing>]).
18201821
# This will be only confusing a user even more.
18211822
return
18221823

1824+
class_obj = False
18231825
if isinstance(subtype, TupleType):
18241826
if not isinstance(subtype.partial_fallback, Instance):
18251827
return
@@ -1828,6 +1830,21 @@ def report_protocol_problems(
18281830
if not isinstance(subtype.fallback, Instance):
18291831
return
18301832
subtype = subtype.fallback
1833+
elif isinstance(subtype, TypeType):
1834+
if not isinstance(subtype.item, Instance):
1835+
return
1836+
class_obj = True
1837+
subtype = subtype.item
1838+
elif isinstance(subtype, CallableType):
1839+
if not subtype.is_type_obj():
1840+
return
1841+
ret_type = get_proper_type(subtype.ret_type)
1842+
if isinstance(ret_type, TupleType):
1843+
ret_type = ret_type.partial_fallback
1844+
if not isinstance(ret_type, Instance):
1845+
return
1846+
class_obj = True
1847+
subtype = ret_type
18311848

18321849
# Report missing members
18331850
missing = get_missing_protocol_members(subtype, supertype)
@@ -1836,20 +1853,29 @@ def report_protocol_problems(
18361853
and len(missing) < len(supertype.type.protocol_members)
18371854
and len(missing) <= MAX_ITEMS
18381855
):
1839-
self.note(
1840-
'"{}" is missing following "{}" protocol member{}:'.format(
1841-
subtype.type.name, supertype.type.name, plural_s(missing)
1842-
),
1843-
context,
1844-
code=code,
1845-
)
1846-
self.note(", ".join(missing), context, offset=OFFSET, code=code)
1856+
if missing == ["__call__"] and class_obj:
1857+
self.note(
1858+
'"{}" has constructor incompatible with "__call__" of "{}"'.format(
1859+
subtype.type.name, supertype.type.name
1860+
),
1861+
context,
1862+
code=code,
1863+
)
1864+
else:
1865+
self.note(
1866+
'"{}" is missing following "{}" protocol member{}:'.format(
1867+
subtype.type.name, supertype.type.name, plural_s(missing)
1868+
),
1869+
context,
1870+
code=code,
1871+
)
1872+
self.note(", ".join(missing), context, offset=OFFSET, code=code)
18471873
elif len(missing) > MAX_ITEMS or len(missing) == len(supertype.type.protocol_members):
18481874
# This is an obviously wrong type: too many missing members
18491875
return
18501876

18511877
# Report member type conflicts
1852-
conflict_types = get_conflict_protocol_types(subtype, supertype)
1878+
conflict_types = get_conflict_protocol_types(subtype, supertype, class_obj=class_obj)
18531879
if conflict_types and (
18541880
not is_subtype(subtype, erase_type(supertype))
18551881
or not subtype.type.defn.type_vars
@@ -1875,29 +1901,43 @@ def report_protocol_problems(
18751901
else:
18761902
self.note("Expected:", context, offset=OFFSET, code=code)
18771903
if isinstance(exp, CallableType):
1878-
self.note(pretty_callable(exp), context, offset=2 * OFFSET, code=code)
1904+
self.note(
1905+
pretty_callable(exp, skip_self=class_obj),
1906+
context,
1907+
offset=2 * OFFSET,
1908+
code=code,
1909+
)
18791910
else:
18801911
assert isinstance(exp, Overloaded)
1881-
self.pretty_overload(exp, context, 2 * OFFSET, code=code)
1912+
self.pretty_overload(
1913+
exp, context, 2 * OFFSET, code=code, skip_self=class_obj
1914+
)
18821915
self.note("Got:", context, offset=OFFSET, code=code)
18831916
if isinstance(got, CallableType):
1884-
self.note(pretty_callable(got), context, offset=2 * OFFSET, code=code)
1917+
self.note(
1918+
pretty_callable(got, skip_self=class_obj),
1919+
context,
1920+
offset=2 * OFFSET,
1921+
code=code,
1922+
)
18851923
else:
18861924
assert isinstance(got, Overloaded)
1887-
self.pretty_overload(got, context, 2 * OFFSET, code=code)
1925+
self.pretty_overload(
1926+
got, context, 2 * OFFSET, code=code, skip_self=class_obj
1927+
)
18881928
self.print_more(conflict_types, context, OFFSET, MAX_ITEMS, code=code)
18891929

18901930
# Report flag conflicts (i.e. settable vs read-only etc.)
1891-
conflict_flags = get_bad_protocol_flags(subtype, supertype)
1931+
conflict_flags = get_bad_protocol_flags(subtype, supertype, class_obj=class_obj)
18921932
for name, subflags, superflags in conflict_flags[:MAX_ITEMS]:
1893-
if IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags:
1933+
if not class_obj and IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags:
18941934
self.note(
18951935
"Protocol member {}.{} expected instance variable,"
18961936
" got class variable".format(supertype.type.name, name),
18971937
context,
18981938
code=code,
18991939
)
1900-
if IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags:
1940+
if not class_obj and IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags:
19011941
self.note(
19021942
"Protocol member {}.{} expected class variable,"
19031943
" got instance variable".format(supertype.type.name, name),
@@ -1919,6 +1959,13 @@ def report_protocol_problems(
19191959
context,
19201960
code=code,
19211961
)
1962+
if class_obj and IS_SETTABLE in superflags and IS_CLASSVAR not in subflags:
1963+
self.note(
1964+
"Only class variables allowed for class object access on protocols,"
1965+
' {} is an instance variable of "{}"'.format(name, subtype.type.name),
1966+
context,
1967+
code=code,
1968+
)
19221969
self.print_more(conflict_flags, context, OFFSET, MAX_ITEMS, code=code)
19231970

19241971
def pretty_overload(
@@ -1930,6 +1977,7 @@ def pretty_overload(
19301977
add_class_or_static_decorator: bool = False,
19311978
allow_dups: bool = False,
19321979
code: ErrorCode | None = None,
1980+
skip_self: bool = False,
19331981
) -> None:
19341982
for item in tp.items:
19351983
self.note("@overload", context, offset=offset, allow_dups=allow_dups, code=code)
@@ -1940,7 +1988,11 @@ def pretty_overload(
19401988
self.note(decorator, context, offset=offset, allow_dups=allow_dups, code=code)
19411989

19421990
self.note(
1943-
pretty_callable(item), context, offset=offset, allow_dups=allow_dups, code=code
1991+
pretty_callable(item, skip_self=skip_self),
1992+
context,
1993+
offset=offset,
1994+
allow_dups=allow_dups,
1995+
code=code,
19441996
)
19451997

19461998
def print_more(
@@ -2373,10 +2425,14 @@ def pretty_class_or_static_decorator(tp: CallableType) -> str | None:
23732425
return None
23742426

23752427

2376-
def pretty_callable(tp: CallableType) -> str:
2428+
def pretty_callable(tp: CallableType, skip_self: bool = False) -> str:
23772429
"""Return a nice easily-readable representation of a callable type.
23782430
For example:
23792431
def [T <: int] f(self, x: int, y: T) -> None
2432+
2433+
If skip_self is True, print an actual callable type, as it would appear
2434+
when bound on an instance/class, rather than how it would appear in the
2435+
defining statement.
23802436
"""
23812437
s = ""
23822438
asterisk = False
@@ -2420,7 +2476,11 @@ def [T <: int] f(self, x: int, y: T) -> None
24202476
and hasattr(tp.definition, "arguments")
24212477
):
24222478
definition_arg_names = [arg.variable.name for arg in tp.definition.arguments]
2423-
if len(definition_arg_names) > len(tp.arg_names) and definition_arg_names[0]:
2479+
if (
2480+
len(definition_arg_names) > len(tp.arg_names)
2481+
and definition_arg_names[0]
2482+
and not skip_self
2483+
):
24242484
if s:
24252485
s = ", " + s
24262486
s = definition_arg_names[0] + s
@@ -2487,7 +2547,9 @@ def get_missing_protocol_members(left: Instance, right: Instance) -> list[str]:
24872547
return missing
24882548

24892549

2490-
def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[str, Type, Type]]:
2550+
def get_conflict_protocol_types(
2551+
left: Instance, right: Instance, class_obj: bool = False
2552+
) -> list[tuple[str, Type, Type]]:
24912553
"""Find members that are defined in 'left' but have incompatible types.
24922554
Return them as a list of ('member', 'got', 'expected').
24932555
"""
@@ -2498,7 +2560,7 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[s
24982560
continue
24992561
supertype = find_member(member, right, left)
25002562
assert supertype is not None
2501-
subtype = find_member(member, left, left)
2563+
subtype = find_member(member, left, left, class_obj=class_obj)
25022564
if not subtype:
25032565
continue
25042566
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
@@ -2510,7 +2572,7 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[s
25102572

25112573

25122574
def get_bad_protocol_flags(
2513-
left: Instance, right: Instance
2575+
left: Instance, right: Instance, class_obj: bool = False
25142576
) -> list[tuple[str, set[int], set[int]]]:
25152577
"""Return all incompatible attribute flags for members that are present in both
25162578
'left' and 'right'.
@@ -2536,6 +2598,9 @@ def get_bad_protocol_flags(
25362598
and IS_SETTABLE not in subflags
25372599
or IS_CLASS_OR_STATIC in superflags
25382600
and IS_CLASS_OR_STATIC not in subflags
2601+
or class_obj
2602+
and IS_SETTABLE in superflags
2603+
and IS_CLASSVAR not in subflags
25392604
):
25402605
bad_flags.append((name, subflags, superflags))
25412606
return bad_flags

0 commit comments

Comments
 (0)