Skip to content

Use checkmember.py to check protocol subtyping #18943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/mypy_primer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
--debug \
--additional-flags="--debug-serialize" \
--output concise \
--show-speed-regression \
| tee diff_${{ matrix.shard-index }}.txt
) || [ $? -eq 1 ]
- if: ${{ matrix.shard-index == 0 }}
Expand Down
5 changes: 3 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mypy import errorcodes as codes, join, message_registry, nodes, operators
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
from mypy.checker_shared import CheckerScope, TypeCheckerSharedApi, TypeRange
from mypy.checker_state import checker_state
from mypy.checkmember import (
MemberContext,
analyze_class_attribute_access,
Expand Down Expand Up @@ -455,7 +456,7 @@ def check_first_pass(self) -> None:
Deferred functions will be processed by check_second_pass().
"""
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional):
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
self.errors.set_file(
self.path, self.tree.fullname, scope=self.tscope, options=self.options
)
Expand Down Expand Up @@ -496,7 +497,7 @@ def check_second_pass(
This goes through deferred nodes, returning True if there were any.
"""
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional):
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
if not todo and not self.deferred_nodes:
return False
self.errors.set_file(
Expand Down
30 changes: 30 additions & 0 deletions mypy/checker_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Final

from mypy.checker_shared import TypeCheckerSharedApi

# This is global mutable state. Don't add anything here unless there's a very
# good reason.


class TypeCheckerState:
    """Holder for the type checker instance that is currently installed.

    The single attribute ``type_checker`` is either ``None`` or the checker
    object most recently installed through :meth:`set`.
    """

    # Wrap this in a class since it's faster than using a module-level attribute.

    def __init__(self, type_checker: TypeCheckerSharedApi | None) -> None:
        # Value varies by file being processed
        self.type_checker = type_checker

    @contextmanager
    def set(self, value: TypeCheckerSharedApi) -> Iterator[None]:
        """Temporarily install ``value`` as ``self.type_checker``.

        The previously stored value is put back when the ``with`` block exits,
        including when the block raises, so nested uses unwind in order.
        """
        saved = self.type_checker
        self.type_checker = value
        try:
            yield
        finally:
            # Always restore, even on error, so a failure inside the block
            # cannot leave a stale checker installed.
            self.type_checker = saved


# Shared module-level instance; it starts with no type checker installed
# (type_checker=None) until a caller installs one via TypeCheckerState.set().
checker_state: Final = TypeCheckerState(type_checker=None)
56 changes: 25 additions & 31 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
is_self: bool = False,
rvalue: Expression | None = None,
suppress_errors: bool = False,
preserve_type_var_ids: bool = False,
) -> None:
self.is_lvalue = is_lvalue
self.is_super = is_super
Expand All @@ -113,6 +114,10 @@ def __init__(
assert is_lvalue
self.rvalue = rvalue
self.suppress_errors = suppress_errors
# This attribute is only used to preserve old protocol member access logic.
# It is needed to avoid infinite recursion in cases involving self-referential
# generic methods, see find_member() for details. Do not use for other purposes!
self.preserve_type_var_ids = preserve_type_var_ids

def named_type(self, name: str) -> Instance:
return self.chk.named_type(name)
Expand Down Expand Up @@ -143,6 +148,7 @@ def copy_modified(
no_deferral=self.no_deferral,
rvalue=self.rvalue,
suppress_errors=self.suppress_errors,
preserve_type_var_ids=self.preserve_type_var_ids,
)
if self_type is not None:
mx.self_type = self_type
Expand Down Expand Up @@ -232,8 +238,6 @@ def analyze_member_access(
def _analyze_member_access(
name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None
) -> Type:
# TODO: This and following functions share some logic with subtypes.find_member;
# consider refactoring.
typ = get_proper_type(typ)
if isinstance(typ, Instance):
return analyze_instance_member_access(name, typ, mx, override_info)
Expand Down Expand Up @@ -358,7 +362,8 @@ def analyze_instance_member_access(
return AnyType(TypeOfAny.special_form)
assert isinstance(method.type, Overloaded)
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not mx.preserve_type_var_ids:
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
if isinstance(method, (FuncDef, OverloadedFuncDef)) and method.is_trivial_self:
signature = bind_self_fast(signature, mx.self_type)
Expand Down Expand Up @@ -943,7 +948,8 @@ def analyze_var(
def expand_without_binding(
typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext
) -> Type:
typ = freshen_all_functions_type_vars(typ)
if not mx.preserve_type_var_ids:
typ = freshen_all_functions_type_vars(typ)
typ = expand_self_type_if_needed(typ, mx, var, original_itype)
expanded = expand_type_by_instance(typ, itype)
freeze_all_type_vars(expanded)
Expand All @@ -958,7 +964,8 @@ def expand_and_bind_callable(
mx: MemberContext,
is_trivial_self: bool,
) -> Type:
functype = freshen_all_functions_type_vars(functype)
if not mx.preserve_type_var_ids:
functype = freshen_all_functions_type_vars(functype)
typ = get_proper_type(expand_self_type(var, functype, mx.original_type))
assert isinstance(typ, FunctionLike)
if is_trivial_self:
Expand Down Expand Up @@ -1056,10 +1063,12 @@ def f(self: S) -> T: ...
return functype
else:
selfarg = get_proper_type(item.arg_types[0])
# This level of erasure matches the one in checker.check_func_def(),
# better keep these two checks consistent.
if subtypes.is_subtype(
# This matches similar special-casing in bind_self(), see more details there.
self_callable = name == "__call__" and isinstance(selfarg, CallableType)
if self_callable or subtypes.is_subtype(
dispatched_arg_type,
# This level of erasure matches the one in checker.check_func_def(),
# better keep these two checks consistent.
erase_typevars(erase_to_bound(selfarg)),
# This is to work around the fact that erased ParamSpec and TypeVarTuple
# callables are not always compatible with non-erased ones both ways.
Expand Down Expand Up @@ -1220,9 +1229,6 @@ def analyze_class_attribute_access(
is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or (
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class
)
is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or (
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static
)
t = get_proper_type(t)
is_trivial_self = False
if isinstance(node.node, Decorator):
Expand All @@ -1236,8 +1242,7 @@ def analyze_class_attribute_access(
t,
isuper,
is_classmethod,
is_staticmethod,
mx.self_type,
mx,
original_vars=original_vars,
is_trivial_self=is_trivial_self,
)
Expand Down Expand Up @@ -1372,8 +1377,7 @@ def add_class_tvars(
t: ProperType,
isuper: Instance | None,
is_classmethod: bool,
is_staticmethod: bool,
original_type: Type,
mx: MemberContext,
original_vars: Sequence[TypeVarLikeType] | None = None,
is_trivial_self: bool = False,
) -> Type:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function does not appear to be a performance bottleneck (at least in self check).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JukkaL If you have time, could you please check whether there is any slowdown caused by bind_self() and check_self_arg()? Although they are not modified, they may be called much more often now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_self_arg could be more expensive -- it appears to consume an extra ~0.5% of runtime in this PR. We are now spending maybe 2-3% of CPU in it, so it's quite hot, but it already was pretty hot before this PR. This could be noise though.

I didn't see any major change in bind_self when doing self check. It's pretty hot both before and after, though less hot than check_self_arg.

Expand All @@ -1392,9 +1396,6 @@ class B(A[str]): pass
isuper: Current instance mapped to the superclass where method was defined, this
is usually done by map_instance_to_supertype()
is_classmethod: True if this method is decorated with @classmethod
is_staticmethod: True if this method is decorated with @staticmethod
original_type: The value of the type B in the expression B.foo() or the corresponding
component in case of a union (this is used to bind the self-types)
original_vars: Type variables of the class callable on which the method was accessed
is_trivial_self: if True, we can use fast path for bind_self().
Returns:
Expand All @@ -1416,14 +1417,14 @@ class B(A[str]): pass
# (i.e. appear in the return type of the class object on which the method was accessed).
if isinstance(t, CallableType):
tvars = original_vars if original_vars is not None else []
t = freshen_all_functions_type_vars(t)
if not mx.preserve_type_var_ids:
t = freshen_all_functions_type_vars(t)
if is_classmethod:
if is_trivial_self:
t = bind_self_fast(t, original_type)
t = bind_self_fast(t, mx.self_type)
else:
t = bind_self(t, original_type, is_classmethod=True)
if is_classmethod or is_staticmethod:
assert isuper is not None
t = bind_self(t, mx.self_type, is_classmethod=True)
if isuper is not None:
t = expand_type_by_instance(t, isuper)
freeze_all_type_vars(t)
return t.copy_modified(variables=list(tvars) + list(t.variables))
Expand All @@ -1432,14 +1433,7 @@ class B(A[str]): pass
[
cast(
CallableType,
add_class_tvars(
item,
isuper,
is_classmethod,
is_staticmethod,
original_type,
original_vars=original_vars,
),
add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars),
)
for item in t.items
]
Expand Down
9 changes: 7 additions & 2 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2220,8 +2220,13 @@ def report_protocol_problems(
exp = get_proper_type(exp)
got = get_proper_type(got)
setter_suffix = " setter type" if is_lvalue else ""
if not isinstance(exp, (CallableType, Overloaded)) or not isinstance(
got, (CallableType, Overloaded)
if (
not isinstance(exp, (CallableType, Overloaded))
or not isinstance(got, (CallableType, Overloaded))
# If expected type is a type object, it means it is a nested class.
# Showing constructor signature in errors would be confusing in this case,
# since we don't check the signature, only subclassing of type objects.
or exp.is_type_obj()
):
self.note(
"{}: expected{} {}, got {}".format(
Expand Down
8 changes: 5 additions & 3 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,13 @@ class C: pass
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Callable, NamedTuple, TypeVar
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar

from mypy_extensions import mypyc_attr, trait

from mypy.errorcodes import ErrorCode
from mypy.lookup import lookup_fully_qualified
from mypy.message_registry import ErrorMessage
from mypy.messages import MessageBuilder
from mypy.nodes import (
ArgKind,
CallExpr,
Expand All @@ -138,7 +137,6 @@ class C: pass
TypeInfo,
)
from mypy.options import Options
from mypy.tvar_scope import TypeVarLikeScope
from mypy.types import (
CallableType,
FunctionLike,
Expand All @@ -149,6 +147,10 @@ class C: pass
UnboundType,
)

if TYPE_CHECKING:
from mypy.messages import MessageBuilder
from mypy.tvar_scope import TypeVarLikeScope


@trait
class TypeAnalyzerPluginInterface:
Expand Down
84 changes: 79 additions & 5 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import mypy.applytype
import mypy.constraints
import mypy.typeops
from mypy.checker_state import checker_state
from mypy.erasetype import erase_type
from mypy.expandtype import (
expand_self_type,
Expand All @@ -26,6 +27,7 @@
COVARIANT,
INVARIANT,
VARIANCE_NOT_READY,
Context,
Decorator,
FuncBase,
OverloadedFuncDef,
Expand Down Expand Up @@ -717,8 +719,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
elif isinstance(right, Instance):
if right.type.is_protocol and "__call__" in right.type.protocol_members:
# OK, a callable can implement a protocol with a `__call__` member.
# TODO: we should probably explicitly exclude self-types in this case.
call = find_member("__call__", right, left, is_operator=True)
call = find_member("__call__", right, right, is_operator=True)
assert call is not None
if self._is_subtype(left, call):
if len(right.type.protocol_members) == 1:
Expand Down Expand Up @@ -954,7 +955,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
if isinstance(right, Instance):
if right.type.is_protocol and "__call__" in right.type.protocol_members:
# same as for CallableType
call = find_member("__call__", right, left, is_operator=True)
call = find_member("__call__", right, right, is_operator=True)
assert call is not None
if self._is_subtype(left, call):
if len(right.type.protocol_members) == 1:
Expand Down Expand Up @@ -1266,14 +1267,87 @@ def find_member(
is_operator: bool = False,
class_obj: bool = False,
is_lvalue: bool = False,
) -> Type | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the remaining performance regression comes from find_member. Would it be feasible to add a fast path that could be used in the majority of (simple) cases? This only makes sense if the semantics would remain identical. We'd first try the fast path, and if it can't be used (not a simple case), we'd fall back to the general implementation that is added here (after from mypy.checkmember import ...). The fast path might look a bit like find_member_simple.

That fast path might cover access to normal attribute/method via instance when there are no self types or properties, for example. Maybe we can avoid creating MemberContext and using filter_errors.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be harder to tune, so although I agree we should do this, I would leave this optimization for later.

type_checker = checker_state.type_checker
if type_checker is None:
# Unfortunately, there are many scenarios where someone calls is_subtype() before
# type checking phase. In this case we fall back to the old (incomplete) logic.
# TODO: reduce number of such cases (e.g. semanal_typeargs, post-semanal plugins).
return find_member_simple(
name, itype, subtype, is_operator=is_operator, class_obj=class_obj, is_lvalue=is_lvalue
)

# We don't use ATTR_DEFINED error code below (since missing attributes can cause various
# other error codes), instead we perform quick node lookup with all the fallbacks.
info = itype.type
sym = info.get(name)
node = sym.node if sym else None
if not node:
name_not_found = True
if (
name not in ["__getattr__", "__setattr__", "__getattribute__"]
and not is_operator
and not class_obj
and itype.extra_attrs is None # skip ModuleType.__getattr__
):
for method_name in ("__getattribute__", "__getattr__"):
method = info.get_method(method_name)
if method and method.info.fullname != "builtins.object":
name_not_found = False
break
if name_not_found:
if info.fallback_to_any or class_obj and info.meta_fallback_to_any:
return AnyType(TypeOfAny.special_form)
if itype.extra_attrs and name in itype.extra_attrs.attrs:
return itype.extra_attrs.attrs[name]
return None

from mypy.checkmember import (
MemberContext,
analyze_class_attribute_access,
analyze_instance_member_access,
)

mx = MemberContext(
is_lvalue=is_lvalue,
is_super=False,
is_operator=is_operator,
original_type=itype,
self_type=subtype,
context=Context(), # all errors are filtered, but this is a required argument
chk=type_checker,
suppress_errors=True,
# This is needed to avoid infinite recursion in situations involving protocols like
# class P(Protocol[T]):
# def combine(self, other: P[S]) -> P[Tuple[T, S]]: ...
# Normally we call freshen_all_functions_type_vars() during attribute access,
# to avoid type variable id collisions, but for protocols this means we can't
# use the assumption stack, that will grow indefinitely.
# TODO: find a cleaner solution that doesn't involve massive perf impact.
preserve_type_var_ids=True,
)
with type_checker.msg.filter_errors(filter_deprecated=True):
if class_obj:
fallback = itype.type.metaclass_type or mx.named_type("builtins.type")
return analyze_class_attribute_access(itype, name, mx, mcs_fallback=fallback)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't tell why this PR doesn't fix #17567, because this should work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this is because the fallback to metaclass is handled by the caller of analyze_class_attribute_access(), not by analyze_class_attribute_access() itself. I guess I will need to copy this fallback logic (which may be a bit ugly because we will also need to update get_member_flags()). I will look into it later.

else:
return analyze_instance_member_access(name, itype, mx, info)


def find_member_simple(
name: str,
itype: Instance,
subtype: Type,
*,
is_operator: bool = False,
class_obj: bool = False,
is_lvalue: bool = False,
) -> Type | None:
"""Find the type of member by 'name' in 'itype's TypeInfo.

Find the member type after applying type arguments from 'itype', and binding
'self' to 'subtype'. Return None if member was not found.
"""
# TODO: this code shares some logic with checkmember.analyze_member_access,
# consider refactoring.
info = itype.type
method = info.get_method(name)
if method:
Expand Down
Loading