Skip to content

Allow classes as protocol implementations #13501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 91 additions & 26 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,9 @@ def incompatible_argument_note(
context: Context,
code: ErrorCode | None,
) -> None:
if isinstance(original_caller_type, (Instance, TupleType, TypedDictType)):
if isinstance(
original_caller_type, (Instance, TupleType, TypedDictType, TypeType, CallableType)
):
if isinstance(callee_type, Instance) and callee_type.type.is_protocol:
self.report_protocol_problems(
original_caller_type, callee_type, context, code=code
Expand Down Expand Up @@ -1791,7 +1793,7 @@ def impossible_intersection(

def report_protocol_problems(
self,
subtype: Instance | TupleType | TypedDictType,
subtype: Instance | TupleType | TypedDictType | TypeType | CallableType,
supertype: Instance,
context: Context,
*,
Expand All @@ -1811,15 +1813,15 @@ def report_protocol_problems(
exclusions: dict[type, list[str]] = {
TypedDictType: ["typing.Mapping"],
TupleType: ["typing.Iterable", "typing.Sequence"],
Instance: [],
}
if supertype.type.fullname in exclusions[type(subtype)]:
if supertype.type.fullname in exclusions.get(type(subtype), []):
return
if any(isinstance(tp, UninhabitedType) for tp in get_proper_types(supertype.args)):
# We don't want to add notes for failed inference (e.g. Iterable[<nothing>]).
# This will be only confusing a user even more.
return

class_obj = False
if isinstance(subtype, TupleType):
if not isinstance(subtype.partial_fallback, Instance):
return
Expand All @@ -1828,6 +1830,21 @@ def report_protocol_problems(
if not isinstance(subtype.fallback, Instance):
return
subtype = subtype.fallback
elif isinstance(subtype, TypeType):
if not isinstance(subtype.item, Instance):
return
class_obj = True
subtype = subtype.item
elif isinstance(subtype, CallableType):
if not subtype.is_type_obj():
return
ret_type = get_proper_type(subtype.ret_type)
if isinstance(ret_type, TupleType):
ret_type = ret_type.partial_fallback
if not isinstance(ret_type, Instance):
return
class_obj = True
subtype = ret_type

# Report missing members
missing = get_missing_protocol_members(subtype, supertype)
Expand All @@ -1836,20 +1853,29 @@ def report_protocol_problems(
and len(missing) < len(supertype.type.protocol_members)
and len(missing) <= MAX_ITEMS
):
self.note(
'"{}" is missing following "{}" protocol member{}:'.format(
subtype.type.name, supertype.type.name, plural_s(missing)
),
context,
code=code,
)
self.note(", ".join(missing), context, offset=OFFSET, code=code)
if missing == ["__call__"] and class_obj:
self.note(
'"{}" has constructor incompatible with "__call__" of "{}"'.format(
subtype.type.name, supertype.type.name
),
context,
code=code,
)
else:
self.note(
'"{}" is missing following "{}" protocol member{}:'.format(
subtype.type.name, supertype.type.name, plural_s(missing)
),
context,
code=code,
)
self.note(", ".join(missing), context, offset=OFFSET, code=code)
elif len(missing) > MAX_ITEMS or len(missing) == len(supertype.type.protocol_members):
# This is an obviously wrong type: too many missing members
return

# Report member type conflicts
conflict_types = get_conflict_protocol_types(subtype, supertype)
conflict_types = get_conflict_protocol_types(subtype, supertype, class_obj=class_obj)
if conflict_types and (
not is_subtype(subtype, erase_type(supertype))
or not subtype.type.defn.type_vars
Expand All @@ -1875,29 +1901,43 @@ def report_protocol_problems(
else:
self.note("Expected:", context, offset=OFFSET, code=code)
if isinstance(exp, CallableType):
self.note(pretty_callable(exp), context, offset=2 * OFFSET, code=code)
self.note(
pretty_callable(exp, skip_self=class_obj),
context,
offset=2 * OFFSET,
code=code,
)
else:
assert isinstance(exp, Overloaded)
self.pretty_overload(exp, context, 2 * OFFSET, code=code)
self.pretty_overload(
exp, context, 2 * OFFSET, code=code, skip_self=class_obj
)
self.note("Got:", context, offset=OFFSET, code=code)
if isinstance(got, CallableType):
self.note(pretty_callable(got), context, offset=2 * OFFSET, code=code)
self.note(
pretty_callable(got, skip_self=class_obj),
context,
offset=2 * OFFSET,
code=code,
)
else:
assert isinstance(got, Overloaded)
self.pretty_overload(got, context, 2 * OFFSET, code=code)
self.pretty_overload(
got, context, 2 * OFFSET, code=code, skip_self=class_obj
)
self.print_more(conflict_types, context, OFFSET, MAX_ITEMS, code=code)

# Report flag conflicts (i.e. settable vs read-only etc.)
conflict_flags = get_bad_protocol_flags(subtype, supertype)
conflict_flags = get_bad_protocol_flags(subtype, supertype, class_obj=class_obj)
for name, subflags, superflags in conflict_flags[:MAX_ITEMS]:
if IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags:
if not class_obj and IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags:
self.note(
"Protocol member {}.{} expected instance variable,"
" got class variable".format(supertype.type.name, name),
context,
code=code,
)
if IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags:
if not class_obj and IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags:
self.note(
"Protocol member {}.{} expected class variable,"
" got instance variable".format(supertype.type.name, name),
Expand All @@ -1919,6 +1959,13 @@ def report_protocol_problems(
context,
code=code,
)
if class_obj and IS_SETTABLE in superflags and IS_CLASSVAR not in subflags:
self.note(
"Only class variables allowed for class object access on protocols,"
' {} is an instance variable of "{}"'.format(name, subtype.type.name),
context,
code=code,
)
self.print_more(conflict_flags, context, OFFSET, MAX_ITEMS, code=code)

def pretty_overload(
Expand All @@ -1930,6 +1977,7 @@ def pretty_overload(
add_class_or_static_decorator: bool = False,
allow_dups: bool = False,
code: ErrorCode | None = None,
skip_self: bool = False,
) -> None:
for item in tp.items:
self.note("@overload", context, offset=offset, allow_dups=allow_dups, code=code)
Expand All @@ -1940,7 +1988,11 @@ def pretty_overload(
self.note(decorator, context, offset=offset, allow_dups=allow_dups, code=code)

self.note(
pretty_callable(item), context, offset=offset, allow_dups=allow_dups, code=code
pretty_callable(item, skip_self=skip_self),
context,
offset=offset,
allow_dups=allow_dups,
code=code,
)

def print_more(
Expand Down Expand Up @@ -2373,10 +2425,14 @@ def pretty_class_or_static_decorator(tp: CallableType) -> str | None:
return None


def pretty_callable(tp: CallableType) -> str:
def pretty_callable(tp: CallableType, skip_self: bool = False) -> str:
"""Return a nice easily-readable representation of a callable type.
For example:
def [T <: int] f(self, x: int, y: T) -> None

If skip_self is True, print an actual callable type, as it would appear
when bound on an instance/class, rather than how it would appear in the
defining statement.
"""
s = ""
asterisk = False
Expand Down Expand Up @@ -2420,7 +2476,11 @@ def [T <: int] f(self, x: int, y: T) -> None
and hasattr(tp.definition, "arguments")
):
definition_arg_names = [arg.variable.name for arg in tp.definition.arguments]
if len(definition_arg_names) > len(tp.arg_names) and definition_arg_names[0]:
if (
len(definition_arg_names) > len(tp.arg_names)
and definition_arg_names[0]
and not skip_self
):
if s:
s = ", " + s
s = definition_arg_names[0] + s
Expand Down Expand Up @@ -2487,7 +2547,9 @@ def get_missing_protocol_members(left: Instance, right: Instance) -> list[str]:
return missing


def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[str, Type, Type]]:
def get_conflict_protocol_types(
left: Instance, right: Instance, class_obj: bool = False
) -> list[tuple[str, Type, Type]]:
"""Find members that are defined in 'left' but have incompatible types.
Return them as a list of ('member', 'got', 'expected').
"""
Expand All @@ -2498,7 +2560,7 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[s
continue
supertype = find_member(member, right, left)
assert supertype is not None
subtype = find_member(member, left, left)
subtype = find_member(member, left, left, class_obj=class_obj)
if not subtype:
continue
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
Expand All @@ -2510,7 +2572,7 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[s


def get_bad_protocol_flags(
left: Instance, right: Instance
left: Instance, right: Instance, class_obj: bool = False
) -> list[tuple[str, set[int], set[int]]]:
"""Return all incompatible attribute flags for members that are present in both
'left' and 'right'.
Expand All @@ -2536,6 +2598,9 @@ def get_bad_protocol_flags(
and IS_SETTABLE not in subflags
or IS_CLASS_OR_STATIC in superflags
and IS_CLASS_OR_STATIC not in subflags
or class_obj
and IS_SETTABLE in superflags
and IS_CLASSVAR not in subflags
):
bad_flags.append((name, subflags, superflags))
return bad_flags
Expand Down
Loading