Skip to content

Allow using modules as subtypes of protocols #13513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions misc/proper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@ def isinstance_proper_hook(ctx: FunctionContext) -> Type:
right = get_proper_type(ctx.arg_types[1][0])
for arg in ctx.arg_types[0]:
if (
is_improper_type(arg)
or isinstance(get_proper_type(arg), AnyType)
and is_dangerous_target(right)
):
is_improper_type(arg) or isinstance(get_proper_type(arg), AnyType)
) and is_dangerous_target(right):
if is_special_target(right):
return ctx.default_return_type
ctx.api.fail(
Expand Down
21 changes: 16 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,18 +2285,29 @@ def check_multiple_inheritance(self, typ: TypeInfo) -> None:
if name in base2.names and base2 not in base.mro:
self.check_compatibility(name, base, base2, typ)

def determine_type_of_class_member(self, sym: SymbolTableNode) -> Type | None:
def determine_type_of_member(self, sym: SymbolTableNode) -> Type | None:
if sym.type is not None:
return sym.type
if isinstance(sym.node, FuncBase):
return self.function_type(sym.node)
if isinstance(sym.node, TypeInfo):
# nested class
return type_object_type(sym.node, self.named_type)
if sym.node.typeddict_type:
# We special-case TypedDict, because they don't define any constructor.
return self.expr_checker.typeddict_callable(sym.node)
else:
return type_object_type(sym.node, self.named_type)
if isinstance(sym.node, TypeVarExpr):
# Use of TypeVars is rejected in an expression/runtime context, so
# we don't need to check supertype compatibility for them.
return AnyType(TypeOfAny.special_form)
if isinstance(sym.node, TypeAlias):
with self.msg.filter_errors():
# Suppress any errors, they will be given when analyzing the corresponding node.
# Here we may have incorrect options and location context.
return self.expr_checker.alias_type_in_runtime_context(
sym.node, sym.node.no_args, sym.node
)
# TODO: handle more node kinds here.
return None

def check_compatibility(
Expand Down Expand Up @@ -2327,8 +2338,8 @@ class C(B, A[int]): ... # this is unsafe because...
return
first = base1.names[name]
second = base2.names[name]
first_type = get_proper_type(self.determine_type_of_class_member(first))
second_type = get_proper_type(self.determine_type_of_class_member(second))
first_type = get_proper_type(self.determine_type_of_member(first))
second_type = get_proper_type(self.determine_type_of_member(second))

if isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike):
if first_type.is_type_obj() and second_type.is_type_obj():
Expand Down
32 changes: 25 additions & 7 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
CallableType,
DeletedType,
ErasedType,
ExtraAttrs,
FunctionLike,
Instance,
LiteralType,
Expand Down Expand Up @@ -332,13 +333,7 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
result = erasetype.erase_typevars(result)
elif isinstance(node, MypyFile):
# Reference to a module object.
try:
result = self.named_type("types.ModuleType")
except KeyError:
# In test cases might 'types' may not be available.
# Fall back to a dummy 'object' type instead to
# avoid a crash.
result = self.named_type("builtins.object")
result = self.module_type(node)
elif isinstance(node, Decorator):
result = self.analyze_var_ref(node.var, e)
elif isinstance(node, TypeAlias):
Expand Down Expand Up @@ -374,6 +369,29 @@ def analyze_var_ref(self, var: Var, context: Context) -> Type:
# Implicit 'Any' type.
return AnyType(TypeOfAny.special_form)

def module_type(self, node: MypyFile) -> Instance:
try:
result = self.named_type("types.ModuleType")
except KeyError:
# In test cases might 'types' may not be available.
# Fall back to a dummy 'object' type instead to
# avoid a crash.
result = self.named_type("builtins.object")
module_attrs = {}
immutable = set()
for name, n in node.names.items():
if isinstance(n.node, Var) and n.node.is_final:
immutable.add(name)
typ = self.chk.determine_type_of_member(n)
if typ:
module_attrs[name] = typ
else:
# TODO: what to do about nested module references?
# They are non-trivial because there may be import cycles.
module_attrs[name] = AnyType(TypeOfAny.special_form)
result.extra_attrs = ExtraAttrs(module_attrs, immutable, node.fullname)
return result

def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
"""Type check a call expression."""
if e.analyzed:
Expand Down
3 changes: 3 additions & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,9 @@ def analyze_member_var_access(
return analyze_var(name, v, itype, info, mx, implicit=implicit)
elif isinstance(v, FuncDef):
assert False, "Did not expect a function"
elif isinstance(v, MypyFile):
mx.chk.module_refs.add(v.fullname)
return mx.chk.expr_checker.module_type(v)
elif (
not v
and name not in ["__getattr__", "__setattr__", "__getattribute__"]
Expand Down
2 changes: 1 addition & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def infer_constraints_from_protocol_members(
# The above is safe since at this point we know that 'instance' is a subtype
# of (erased) 'template', therefore it defines all protocol members
res.extend(infer_constraints(temp, inst, self.direction))
if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol.type):
if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol):
# Settable members are invariant, add opposite constraints
res.extend(infer_constraints(temp, inst, neg_op(self.direction)))
return res
Expand Down
43 changes: 22 additions & 21 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1822,6 +1822,7 @@ def report_protocol_problems(
return

class_obj = False
is_module = False
if isinstance(subtype, TupleType):
if not isinstance(subtype.partial_fallback, Instance):
return
Expand All @@ -1845,6 +1846,8 @@ def report_protocol_problems(
return
class_obj = True
subtype = ret_type
if subtype.extra_attrs and subtype.extra_attrs.mod_name:
is_module = True

# Report missing members
missing = get_missing_protocol_members(subtype, supertype)
Expand Down Expand Up @@ -1881,11 +1884,8 @@ def report_protocol_problems(
or not subtype.type.defn.type_vars
or not supertype.type.defn.type_vars
):
self.note(
f"Following member(s) of {format_type(subtype)} have conflicts:",
context,
code=code,
)
type_name = format_type(subtype, module_names=True)
self.note(f"Following member(s) of {type_name} have conflicts:", context, code=code)
for name, got, exp in conflict_types[:MAX_ITEMS]:
exp = get_proper_type(exp)
got = get_proper_type(got)
Expand All @@ -1902,28 +1902,28 @@ def report_protocol_problems(
self.note("Expected:", context, offset=OFFSET, code=code)
if isinstance(exp, CallableType):
self.note(
pretty_callable(exp, skip_self=class_obj),
pretty_callable(exp, skip_self=class_obj or is_module),
context,
offset=2 * OFFSET,
code=code,
)
else:
assert isinstance(exp, Overloaded)
self.pretty_overload(
exp, context, 2 * OFFSET, code=code, skip_self=class_obj
exp, context, 2 * OFFSET, code=code, skip_self=class_obj or is_module
)
self.note("Got:", context, offset=OFFSET, code=code)
if isinstance(got, CallableType):
self.note(
pretty_callable(got, skip_self=class_obj),
pretty_callable(got, skip_self=class_obj or is_module),
context,
offset=2 * OFFSET,
code=code,
)
else:
assert isinstance(got, Overloaded)
self.pretty_overload(
got, context, 2 * OFFSET, code=code, skip_self=class_obj
got, context, 2 * OFFSET, code=code, skip_self=class_obj or is_module
)
self.print_more(conflict_types, context, OFFSET, MAX_ITEMS, code=code)

Expand Down Expand Up @@ -2147,7 +2147,9 @@ def format_callable_args(
return ", ".join(arg_strings)


def format_type_inner(typ: Type, verbosity: int, fullnames: set[str] | None) -> str:
def format_type_inner(
typ: Type, verbosity: int, fullnames: set[str] | None, module_names: bool = False
) -> str:
"""
Convert a type to a relatively short string suitable for error messages.

Expand Down Expand Up @@ -2187,7 +2189,10 @@ def format_literal_value(typ: LiteralType) -> str:
# Get the short name of the type.
if itype.type.fullname in ("types.ModuleType", "_importlib_modulespec.ModuleType"):
# Make some common error messages simpler and tidier.
return "Module"
base_str = "Module"
if itype.extra_attrs and itype.extra_attrs.mod_name and module_names:
return f"{base_str} {itype.extra_attrs.mod_name}"
return base_str
if verbosity >= 2 or (fullnames and itype.type.fullname in fullnames):
base_str = itype.type.fullname
else:
Expand Down Expand Up @@ -2361,7 +2366,7 @@ def find_type_overlaps(*types: Type) -> set[str]:
return overlaps


def format_type(typ: Type, verbosity: int = 0) -> str:
def format_type(typ: Type, verbosity: int = 0, module_names: bool = False) -> str:
"""
Convert a type to a relatively short string suitable for error messages.

Expand All @@ -2372,10 +2377,10 @@ def format_type(typ: Type, verbosity: int = 0) -> str:
modification of the formatted string is required, callers should use
format_type_bare.
"""
return quote_type_string(format_type_bare(typ, verbosity))
return quote_type_string(format_type_bare(typ, verbosity, module_names))


def format_type_bare(typ: Type, verbosity: int = 0) -> str:
def format_type_bare(typ: Type, verbosity: int = 0, module_names: bool = False) -> str:
"""
Convert a type to a relatively short string suitable for error messages.

Expand All @@ -2387,7 +2392,7 @@ def format_type_bare(typ: Type, verbosity: int = 0) -> str:
instead. (The caller may want to use quote_type_string after
processing has happened, to maintain consistent quoting in messages.)
"""
return format_type_inner(typ, verbosity, find_type_overlaps(typ))
return format_type_inner(typ, verbosity, find_type_overlaps(typ), module_names)


def format_type_distinctly(*types: Type, bare: bool = False) -> tuple[str, ...]:
Expand Down Expand Up @@ -2564,7 +2569,7 @@ def get_conflict_protocol_types(
if not subtype:
continue
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
if IS_SETTABLE in get_member_flags(member, right.type):
if IS_SETTABLE in get_member_flags(member, right):
is_compat = is_compat and is_subtype(supertype, subtype)
if not is_compat:
conflicts.append((member, subtype, supertype))
Expand All @@ -2581,11 +2586,7 @@ def get_bad_protocol_flags(
all_flags: list[tuple[str, set[int], set[int]]] = []
for member in right.type.protocol_members:
if find_member(member, left, left):
item = (
member,
get_member_flags(member, left.type),
get_member_flags(member, right.type),
)
item = (member, get_member_flags(member, left), get_member_flags(member, right))
all_flags.append(item)
bad_flags = []
for name, subflags, superflags in all_flags:
Expand Down
3 changes: 3 additions & 0 deletions mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,9 @@ def visit_instance(self, typ: Instance) -> list[str]:
triggers.extend(self.get_type_triggers(arg))
if typ.last_known_value:
triggers.extend(self.get_type_triggers(typ.last_known_value))
if typ.extra_attrs and typ.extra_attrs.mod_name:
# Module as type effectively depends on all module attributes, use wildcard.
triggers.append(make_wildcard_trigger(typ.extra_attrs.mod_name))
return triggers

def visit_type_alias_type(self, typ: TypeAliasType) -> list[str]:
Expand Down
18 changes: 14 additions & 4 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,8 +1010,8 @@ def named_type(fullname: str) -> Instance:
if isinstance(subtype, NoneType) and isinstance(supertype, CallableType):
# We want __hash__ = None idiom to work even without --strict-optional
return False
subflags = get_member_flags(member, left.type, class_obj=class_obj)
superflags = get_member_flags(member, right.type)
subflags = get_member_flags(member, left, class_obj=class_obj)
superflags = get_member_flags(member, right)
if IS_SETTABLE in superflags:
# Check opposite direction for settable attributes.
if not is_subtype(supertype, subtype):
Expand Down Expand Up @@ -1095,10 +1095,12 @@ def find_member(
# PEP 544 doesn't specify anything about such use cases. So we just try
# to do something meaningful (at least we should not crash).
return TypeType(fill_typevars_with_any(v))
if itype.extra_attrs and name in itype.extra_attrs.attrs:
return itype.extra_attrs.attrs[name]
return None


def get_member_flags(name: str, info: TypeInfo, class_obj: bool = False) -> set[int]:
def get_member_flags(name: str, itype: Instance, class_obj: bool = False) -> set[int]:
"""Detect whether a member 'name' is settable, whether it is an
instance or class variable, and whether it is class or static method.

Expand All @@ -1109,6 +1111,7 @@ def get_member_flags(name: str, info: TypeInfo, class_obj: bool = False) -> set[
* IS_CLASS_OR_STATIC: set for methods decorated with @classmethod or
with @staticmethod.
"""
info = itype.type
method = info.get_method(name)
setattr_meth = info.get_method("__setattr__")
if method:
Expand All @@ -1126,11 +1129,18 @@ def get_member_flags(name: str, info: TypeInfo, class_obj: bool = False) -> set[
if not node:
if setattr_meth:
return {IS_SETTABLE}
if itype.extra_attrs and name in itype.extra_attrs.attrs:
flags = set()
if name not in itype.extra_attrs.immutable:
flags.add(IS_SETTABLE)
return flags
return set()
v = node.node
# just a variable
if isinstance(v, Var) and not v.is_property:
flags = {IS_SETTABLE}
flags = set()
if not v.is_final:
flags.add(IS_SETTABLE)
if v.is_classvar:
flags.add(IS_CLASSVAR)
if class_obj and v.is_inferred:
Expand Down
Loading