diff --git a/mypy/checker.py b/mypy/checker.py index 72c53f0500ab..707ead0641cf 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -21,8 +21,8 @@ Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt, ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr, Import, ImportFrom, ImportAll, ImportBase, TypeAlias, - ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF, - CONTRAVARIANT, COVARIANT, INVARIANT, + ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF, CallableDecorator, + CONTRAVARIANT, COVARIANT, INVARIANT, get_callable ) from mypy import nodes from mypy.literals import literal, literal_hash @@ -1619,11 +1619,15 @@ def check_compatibility(self, name: str, base1: TypeInfo, first = base1[name] second = base2[name] first_type = first.type - if first_type is None and isinstance(first.node, FuncBase): - first_type = self.function_type(first.node) + if first_type is None: + method = get_callable(first.node) + if method: + first_type = self.function_type(method) second_type = second.type - if second_type is None and isinstance(second.node, FuncBase): - second_type = self.function_type(second.node) + if second_type is None: + method = get_callable(second.node) + if method: + second_type = self.function_type(method) # TODO: What if some classes are generic? if (isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike)): @@ -3019,10 +3023,17 @@ def visit_decorator(self, e: Decorator) -> None: callable_name=fullname) self.check_untyped_after_decorator(sig, e.func) sig = set_callable_name(sig, e.func) - e.var.type = sig - e.var.is_ready = True if e.func.is_property: self.check_incompatible_property_override(e) + e.var.type = sig + e.var.is_ready = True + if isinstance(sig, CallableType): + if e.func.is_property: + assert isinstance(sig, CallableType) + if isinstance(sig.ret_type, CallableType): + e.callable_decorator = CallableDecorator(e) + else: + e.callable_decorator = CallableDecorator(e) if e.func.info and not e.func.is_dynamic(): self.check_method_override(e) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 5a77ee9e1402..810bc3be974c 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -77,7 +77,7 @@ def analyze_member_access(name: str, # Look up the member. First look up the method dictionary. method = info.get_method(name) - if method: + if method and not method.is_class: if method.is_property: assert isinstance(method, OverloadedFuncDef) first_item = cast(Decorator, method.items[0]) @@ -87,7 +87,7 @@ def analyze_member_access(name: str, msg.cant_assign_to_method(node) signature = function_type(method, builtin_type('builtins.function')) signature = freshen_function_type_vars(signature) - if name == '__new__': + if name == '__new__' or method.is_static: # __new__ is special and behaves like a static method -- don't strip # the first argument. pass diff --git a/mypy/nodes.py b/mypy/nodes.py index a6c59470037c..87b67ab67d15 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -657,6 +657,7 @@ class Decorator(SymbolNode, Statement): # TODO: This is mostly used for the type; consider replacing with a 'type' attribute var = None # type: Var # Represents the decorated function obj is_overload = False + callable_decorator = None # type: Optional[CallableDecorator] def __init__(self, func: FuncDef, decorators: List[Expression], var: 'Var') -> None: @@ -704,6 +705,28 @@ def deserialize(cls, data: JsonDict) -> 'Decorator': return dec +class CallableDecorator(FuncItem): + """A wrapper around a Decorator that allows it to be treated as a callable function""" + def __init__(self, decorator: Decorator) -> None: + super().__init__(decorator.func.arguments, decorator.func.body, + cast('mypy.types.CallableType', decorator.type)) + self.is_final = decorator.is_final + self.is_class = decorator.func.is_class + self.is_property = decorator.func.is_property + self.is_static = decorator.func.is_static + self.is_overload = decorator.func.is_overload + self.is_generator = decorator.func.is_generator + self.is_async_generator = decorator.func.is_async_generator + self.is_awaitable_coroutine = decorator.func.is_awaitable_coroutine + self.expanded = decorator.func.expanded + self.info = decorator.info + self._name = decorator.func.name() + self._fullname = decorator.func._fullname + + def name(self) -> str: + return self._name + + VAR_FLAGS = [ 'is_self', 'is_initialized_in_class', 'is_staticmethod', 'is_classmethod', 'is_property', 'is_settable_property', 'is_suppressed_import', @@ -2308,11 +2331,7 @@ def has_readable_member(self, name: str) -> bool: def get_method(self, name: str) -> Optional[FuncBase]: for cls in self.mro: if name in cls.names: - node = cls.names[name].node - if isinstance(node, FuncBase): - return node - else: - return None + return get_callable(cls.names[name].node) return None def calculate_metaclass_type(self) -> 'Optional[mypy.types.Instance]': @@ -2935,3 +2954,13 @@ def is_class_var(expr: NameExpr) -> bool: if isinstance(expr.node, Var): return expr.node.is_classvar return False + + +def get_callable(node: Optional[Node]) -> Optional[FuncBase]: + """Check if the passed node represents a callable function or funcion-like object""" + if isinstance(node, FuncBase): + return node + elif isinstance(node, Decorator): + return node.callable_decorator + else: + return None diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index 8329a87facb2..5b43860bce3b 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -7,9 +7,9 @@ from mypy.fixup import lookup_qualified_stnode from mypy.nodes import ( Context, Argument, Var, ARG_OPT, ARG_POS, TypeInfo, AssignmentStmt, - TupleExpr, ListExpr, NameExpr, CallExpr, RefExpr, FuncBase, - is_class_var, TempNode, Decorator, MemberExpr, Expression, FuncDef, Block, - PassStmt, SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef + TupleExpr, ListExpr, NameExpr, CallExpr, RefExpr, is_class_var, + TempNode, Decorator, MemberExpr, Expression, FuncDef, Block, + PassStmt, SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef, get_callable ) from mypy.plugins.common import ( _get_argument, _get_bool_argument, _get_decorator_bool_argument @@ -405,9 +405,8 @@ def _parse_converter(ctx: 'mypy.plugin.ClassDefContext', # TODO: Support complex converters, e.g. lambdas, calls, etc. if converter: if isinstance(converter, RefExpr) and converter.node: - if (isinstance(converter.node, FuncBase) - and converter.node.type - and isinstance(converter.node.type, FunctionLike)): + method = get_callable(converter.node) + if method and method.type and isinstance(method.type, FunctionLike): return Converter(converter.node.fullname()) elif isinstance(converter.node, TypeInfo): return Converter(converter.node.fullname()) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index e483e69ff348..7786b0381ee6 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -1,8 +1,8 @@ from typing import List, Optional, Any from mypy.nodes import ( - ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncBase, - FuncDef, PassStmt, RefExpr, SymbolTableNode, Var + ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncDef, + PassStmt, RefExpr, SymbolTableNode, Var, get_callable ) from mypy.plugin import ClassDefContext from mypy.semanal import set_callable_name @@ -53,8 +53,9 @@ def _get_argument(call: CallExpr, name: str) -> Optional[Expression]: callee_type = None # mypyc hack to workaround mypy misunderstanding multiple inheritance (#3603) callee_node = call.callee.node # type: Any - if (isinstance(callee_node, (Var, FuncBase)) - and callee_node.type): + if not isinstance(callee_node, Var): + callee_node = get_callable(callee_node) + if callee_node and callee_node.type: callee_node_type = callee_node.type if isinstance(callee_node_type, Overloaded): # We take the last overload. diff --git a/mypy/server/deps.py b/mypy/server/deps.py index ec4659c0e8aa..69d230e24e07 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -91,9 +91,9 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a ImportFrom, CallExpr, CastExpr, TypeVarExpr, TypeApplication, IndexExpr, UnaryExpr, OpExpr, ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt, TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block, - TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr, + TypeInfo, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr, LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr, - op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods + op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods, get_callable ) from mypy.traverser import TraverserVisitor from mypy.types import ( @@ -127,6 +127,7 @@ def get_dependencies_of_target(module_id: str, # TODO: Add tests for this function. visitor = DependencyVisitor(type_map, python_version, module_tree.alias_deps) visitor.scope.enter_file(module_id) + method = get_callable(target) if isinstance(target, MypyFile): # Only get dependencies of the top-level of the module. Don't recurse into # functions. @@ -134,10 +135,10 @@ def get_dependencies_of_target(module_id: str, # TODO: Recurse into top-level statements and class bodies but skip functions. if not isinstance(defn, (ClassDef, Decorator, FuncDef, OverloadedFuncDef)): defn.accept(visitor) - elif isinstance(target, FuncBase) and target.info: + elif method and method.info: # It's a method. # TODO: Methods in nested classes. - visitor.scope.enter_class(target.info) + visitor.scope.enter_class(method.info) target.accept(visitor) visitor.scope.leave() else: @@ -425,8 +426,10 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: if isinstance(rvalue.callee.node, TypeInfo): # use actual __init__ as a dependency source init = rvalue.callee.node.get('__init__') - if init and isinstance(init.node, FuncBase): - fname = init.node.fullname() + if init: + method = get_callable(init.node) + if method: + fname = method.fullname() else: fname = rvalue.callee.fullname if fname is None: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index d06ad40fdfa1..3ff82d206bde 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -572,7 +572,10 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]: assert isinstance(dec, Decorator) if dec.var.is_settable_property or setattr_meth: return {IS_SETTABLE} - return set() + if method.is_static or method.is_class: + return {IS_CLASS_OR_STATIC} + else: + return set() node = info.get(name) if not node: if setattr_meth: @@ -604,7 +607,12 @@ def find_node_type(node: Union[Var, FuncBase], itype: Instance, subtype: Type) - if typ is None: return AnyType(TypeOfAny.from_error) # We don't need to bind 'self' for static methods, since there is no 'self'. - if isinstance(node, FuncBase) or isinstance(typ, FunctionLike) and not node.is_staticmethod: + need_bind = False + if isinstance(node, FuncBase): + need_bind = not node.is_static + elif isinstance(typ, FunctionLike): + need_bind = not node.is_staticmethod + if need_bind: assert isinstance(typ, FunctionLike) signature = bind_self(typ, subtype) if node.is_property: diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 7784f26df4b2..48af6fd1a646 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -5610,3 +5610,29 @@ from typing import TypeVar, Tuple, Callable T = TypeVar('T') def deco(f: Callable[..., T]) -> Callable[..., Tuple[T, int]]: ... [out] + +[case testDecoratedInit] +from typing import Callable, Any +def dec(func: Callable[[Any], None]) -> Callable[[Any], None]: + return func + +class A: + @dec + def __init__(self): + pass + +reveal_type(A()) # E: Revealed type is '__main__.A' +[out] + +[case testAbstractInit] +from abc import abstractmethod +class A: + @abstractmethod + def __init__(self): ... + +class B(A): + def __init__(self): + pass + +reveal_type(B()) # E: Revealed type is '__main__.B' +[out] diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 876a5fe43f0c..13b51ef4ec43 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -2383,3 +2383,16 @@ def foo() -> None: def lol(): x = foo() + + +[case testNonCallableDecorator] +def dec(func) -> int: + return 1 + +@dec +def f(): + pass + +reveal_type(f) # E: Revealed type is 'builtins.int' +f() # E: "int" not callable +[out]