Skip to content

Make overloads support classmethod and staticmethod #5224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,12 +1289,10 @@ def check_override(self, override: FunctionLike, original: FunctionLike,
# this could be unsafe with reverse operator methods.
fail = True

if isinstance(original, CallableType) and isinstance(override, CallableType):
if (isinstance(original.definition, FuncItem) and
isinstance(override.definition, FuncItem)):
if ((original.definition.is_static or original.definition.is_class) and
not (override.definition.is_static or override.definition.is_class)):
fail = True
if isinstance(original, FunctionLike) and isinstance(override, FunctionLike):
if ((original.is_classmethod() or original.is_staticmethod()) and
not (override.is_classmethod() or override.is_staticmethod())):
fail = True

if fail:
emitted_msg = False
Expand Down Expand Up @@ -3923,8 +3921,6 @@ def is_untyped_decorator(typ: Optional[Type]) -> bool:
def is_static(func: Union[FuncBase, Decorator]) -> bool:
if isinstance(func, Decorator):
return is_static(func.func)
elif isinstance(func, OverloadedFuncDef):
return any(is_static(item) for item in func.items)
elif isinstance(func, FuncItem):
elif isinstance(func, FuncBase):
return func.is_static
return False
assert False, "Unexpected func type: {}".format(type(func))
3 changes: 2 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ def analyze_class_attribute_access(itype: Instance,
return handle_partial_attribute_type(t, is_lvalue, msg, symnode)
if not is_method and (isinstance(t, TypeVarType) or get_type_vars(t)):
msg.fail(messages.GENERIC_INSTANCE_VAR_CLASS_ACCESS, context)
is_classmethod = is_decorated and cast(Decorator, node.node).func.is_class
is_classmethod = ((is_decorated and cast(Decorator, node.node).func.is_class)
or (isinstance(node.node, FuncBase) and node.node.is_class))
return add_class_tvars(t, itype, is_classmethod, builtin_type, original_type)
elif isinstance(node.node, Var):
not_ready_callback(name, context)
Expand Down
8 changes: 7 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TypeInfo, Context, MypyFile, op_methods, FuncDef, reverse_type_aliases,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode,
CallExpr, Expression
CallExpr, Expression, OverloadedFuncDef,
)

# Constants that represent simple type checker error message, i.e. messages
Expand Down Expand Up @@ -942,6 +942,12 @@ def incompatible_typevar_value(self,
self.format(typ)),
context)

def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None:
self.fail(
'Overload does not consistently use the "@{}" '.format(decorator)
+ 'decorator on all function signatures.',
context)

def overloaded_signatures_overlap(self, index1: int, index2: int, context: Context) -> None:
self.fail('Overloaded function signatures {} and {} overlap with '
'incompatible return types'.format(index1, index2), context)
Expand Down
27 changes: 15 additions & 12 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,13 +370,20 @@ def __str__(self) -> str:
return 'ImportedName(%s)' % self.target_fullname


FUNCBASE_FLAGS = [
'is_property', 'is_class', 'is_static',
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we also update astdiff.py according to these flag reshuffling? This may break fine grained incremental mode (in some corner cases).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooh, good point -- I didn't even know that file existed.

I tried making the change + tried adding an test to one of the fine-grained incremental tests. (I'm pretty unfamiliar with fine-grained incremental stuff though, so let me know if I did it incorrectly.)



class FuncBase(Node):
"""Abstract base class for function-like nodes"""

__slots__ = ('type',
'unanalyzed_type',
'info',
'is_property',
'is_class', # Uses "@classmethod"
'is_static', # USes "@staticmethod"
'_fullname',
)

Expand All @@ -391,6 +398,8 @@ def __init__(self) -> None:
# TODO: Type should be Optional[TypeInfo]
self.info = cast(TypeInfo, None)
self.is_property = False
self.is_class = False
self.is_static = False
# Name with module prefix
# TODO: Type should be Optional[str]
self._fullname = cast(str, None)
Expand Down Expand Up @@ -436,8 +445,8 @@ def serialize(self) -> JsonDict:
'items': [i.serialize() for i in self.items],
'type': None if self.type is None else self.type.serialize(),
'fullname': self._fullname,
'is_property': self.is_property,
'impl': None if self.impl is None else self.impl.serialize()
'impl': None if self.impl is None else self.impl.serialize(),
'flags': get_flags(self, FUNCBASE_FLAGS),
}

@classmethod
Expand All @@ -451,7 +460,7 @@ def deserialize(cls, data: JsonDict) -> 'OverloadedFuncDef':
if data.get('type') is not None:
res.type = mypy.types.deserialize_type(data['type'])
res._fullname = data['fullname']
res.is_property = data['is_property']
set_flags(res, data['flags'])
# NOTE: res.info will be set in the fixup phase.
return res

Expand Down Expand Up @@ -481,9 +490,9 @@ def set_line(self, target: Union[Context, int], column: Optional[int] = None) ->
self.variable.set_line(self.line, self.column)


FUNCITEM_FLAGS = [
FUNCITEM_FLAGS = FUNCBASE_FLAGS + [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change adds is_property, is this intentional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is. I made this change partly on principle since is_property actually is a field of FuncBase -- it felt cleaner to just force all subclasses to preserve that field no matter what.

This change also doesn't actually change the serialized output in practice. FuncItem currently has only two subtypes: FuncDef and LambdaExpr. The former subclass previously explicitly set and serialized the is_property field so this change makes no difference there. The latter subclass never really uses is_property but also doesn't have any serialize/deserialize methods, which makes this change moot.

'is_overload', 'is_generator', 'is_coroutine', 'is_async_generator',
'is_awaitable_coroutine', 'is_static', 'is_class',
'is_awaitable_coroutine',
]


Expand All @@ -503,8 +512,6 @@ class FuncItem(FuncBase):
'is_coroutine', # Defined using 'async def' syntax?
'is_async_generator', # Is an async def generator?
'is_awaitable_coroutine', # Decorated with '@{typing,asyncio}.coroutine'?
'is_static', # Uses @staticmethod?
'is_class', # Uses @classmethod?
'expanded', # Variants of function with type variables with values expanded
)

Expand All @@ -525,8 +532,6 @@ def __init__(self,
self.is_coroutine = False
self.is_async_generator = False
self.is_awaitable_coroutine = False
self.is_static = False
self.is_class = False
self.expanded = [] # type: List[FuncItem]

self.min_args = 0
Expand All @@ -547,7 +552,7 @@ def is_dynamic(self) -> bool:


FUNCDEF_FLAGS = FUNCITEM_FLAGS + [
'is_decorated', 'is_conditional', 'is_abstract', 'is_property',
'is_decorated', 'is_conditional', 'is_abstract',
]


Expand All @@ -561,7 +566,6 @@ class FuncDef(FuncItem, SymbolNode, Statement):
'is_decorated',
'is_conditional',
'is_abstract',
'is_property',
'original_def',
)

Expand All @@ -575,7 +579,6 @@ def __init__(self,
self.is_decorated = False
self.is_conditional = False # Defined conditionally (within block)?
self.is_abstract = False
self.is_property = False
# Original conditional definition
self.original_def = None # type: Union[None, FuncDef, Var, Decorator]

Expand Down
10 changes: 10 additions & 0 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,16 @@ def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute],
func_type = stmt.func.type
if isinstance(func_type, CallableType):
func_type.arg_types[0] = ctx.api.class_type(ctx.cls.info)
if isinstance(stmt, OverloadedFuncDef) and stmt.is_class:
func_type = stmt.type
if isinstance(func_type, Overloaded):
class_type = ctx.api.class_type(ctx.cls.info)
for item in func_type.items():
item.arg_types[0] = class_type
if stmt.impl is not None:
assert isinstance(stmt.impl, Decorator)
if isinstance(stmt.impl.func.type, CallableType):
stmt.impl.func.type.arg_types[0] = class_type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's good that you also take care about plugins.



class MethodAdder:
Expand Down
14 changes: 12 additions & 2 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from mypy.nodes import (
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
Context, Decorator, Expression, FuncDef, JsonDict, NameExpr,
SymbolTableNode, TempNode, TypeInfo, Var,
OverloadedFuncDef, SymbolTableNode, TempNode, TypeInfo, Var,
)
from mypy.plugin import ClassDefContext
from mypy.plugins.common import _add_method, _get_decorator_bool_argument
from mypy.types import (
CallableType, Instance, NoneTyp, TypeVarDef, TypeVarType,
CallableType, Instance, NoneTyp, Overloaded, TypeVarDef, TypeVarType,
)

# The set of decorators that generate dataclasses.
Expand Down Expand Up @@ -95,6 +95,16 @@ def transform(self) -> None:
func_type = stmt.func.type
if isinstance(func_type, CallableType):
func_type.arg_types[0] = self._ctx.api.class_type(self._ctx.cls.info)
if isinstance(stmt, OverloadedFuncDef) and stmt.is_class:
func_type = stmt.type
if isinstance(func_type, Overloaded):
class_type = ctx.api.class_type(ctx.cls.info)
for item in func_type.items():
item.arg_types[0] = class_type
if stmt.impl is not None:
assert isinstance(stmt.impl, Decorator)
if isinstance(stmt.impl.func.type, CallableType):
stmt.impl.func.type.arg_types[0] = class_type

# Add an eq method, but only if the class doesn't already have one.
if decorator_arguments['eq'] and info.get('__eq__') is None:
Expand Down
31 changes: 31 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,37 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
# redefinitions already.
return

# We know this is an overload def -- let's handle classmethod and staticmethod
class_status = []
static_status = []
for item in defn.items:
if isinstance(item, Decorator):
inner = item.func
elif isinstance(item, FuncDef):
inner = item
else:
assert False, "The 'item' variable is an unexpected type: {}".format(type(item))
class_status.append(inner.is_class)
static_status.append(inner.is_static)

if defn.impl is not None:
if isinstance(defn.impl, Decorator):
inner = defn.impl.func
elif isinstance(defn.impl, FuncDef):
inner = defn.impl
else:
assert False, "Unexpected impl type: {}".format(type(defn.impl))
class_status.append(inner.is_class)
static_status.append(inner.is_static)

if len(set(class_status)) != 1:
self.msg.overload_inconsistently_applies_decorator('classmethod', defn)
elif len(set(static_status)) != 1:
self.msg.overload_inconsistently_applies_decorator('staticmethod', defn)
else:
defn.is_class = class_status[0]
defn.is_static = static_status[0]

if self.type and not self.is_func_scope():
self.type.names[defn.name()] = SymbolTableNode(MDEF, defn,
typ=defn.type)
Expand Down
6 changes: 3 additions & 3 deletions mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'

from mypy.nodes import (
SymbolTable, TypeInfo, Var, SymbolNode, Decorator, TypeVarExpr,
OverloadedFuncDef, FuncItem, MODULE_REF, TYPE_ALIAS, UNBOUND_IMPORTED, TVAR
FuncBase, OverloadedFuncDef, FuncItem, MODULE_REF, TYPE_ALIAS, UNBOUND_IMPORTED, TVAR
)
from mypy.types import (
Type, TypeVisitor, UnboundType, AnyType, NoneTyp, UninhabitedType,
Expand Down Expand Up @@ -167,13 +167,13 @@ def snapshot_definition(node: Optional[SymbolNode],
The representation is nested tuples and dicts. Only externally
visible attributes are included.
"""
if isinstance(node, (OverloadedFuncDef, FuncItem)):
if isinstance(node, FuncBase):
# TODO: info
if node.type:
signature = snapshot_type(node.type)
else:
signature = snapshot_untyped_signature(node)
return ('Func', common, node.is_property, signature)
return ('Func', common, node.is_property, node.is_class, node.is_static, signature)
elif isinstance(node, Var):
return ('Var', common, snapshot_optional_type(node.type))
elif isinstance(node, Decorator):
Expand Down
4 changes: 4 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def visit_overloaded_func_def(self, o: 'mypy.nodes.OverloadedFuncDef') -> str:
a.insert(0, o.type)
if o.impl:
a.insert(0, o.impl)
if o.is_static:
a.insert(-1, 'Static')
if o.is_class:
a.insert(-1, 'Class')
return self.dump(a, o)

def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> str:
Expand Down
3 changes: 3 additions & 0 deletions mypy/treetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDe
new._fullname = node._fullname
new.type = self.optional_type(node.type)
new.info = node.info
new.is_static = node.is_static
new.is_class = node.is_class
new.is_property = node.is_property
if node.impl:
new.impl = cast(OverloadPart, node.impl.accept(self))
return new
Expand Down
20 changes: 19 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mypy import experiments
from mypy.nodes import (
INVARIANT, SymbolNode, ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT,
FuncDef
FuncBase, FuncDef,
)
from mypy.sharedparse import argument_elide_name
from mypy.util import IdMapper
Expand Down Expand Up @@ -645,6 +645,12 @@ def with_name(self, name: str) -> 'FunctionLike': pass
@abstractmethod
def get_name(self) -> Optional[str]: pass

@abstractmethod
def is_classmethod(self) -> bool: pass

@abstractmethod
def is_staticmethod(self) -> bool: pass


FormalArgument = NamedTuple('FormalArgument', [
('name', Optional[str]),
Expand Down Expand Up @@ -803,6 +809,12 @@ def with_name(self, name: str) -> 'CallableType':
def get_name(self) -> Optional[str]:
return self.name

def is_classmethod(self) -> bool:
return isinstance(self.definition, FuncBase) and self.definition.is_class

def is_staticmethod(self) -> bool:
return isinstance(self.definition, FuncBase) and self.definition.is_static

def max_fixed_args(self) -> int:
n = len(self.arg_types)
if self.is_var_arg:
Expand Down Expand Up @@ -1000,6 +1012,12 @@ def with_name(self, name: str) -> 'Overloaded':
def get_name(self) -> Optional[str]:
return self._items[0].name

def is_classmethod(self) -> bool:
return self._items[0].is_classmethod()

def is_staticmethod(self) -> bool:
return self._items[0].is_staticmethod()

def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_overloaded(self)

Expand Down
32 changes: 32 additions & 0 deletions test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,38 @@ a = A.new()
reveal_type(a.foo) # E: Revealed type is 'def () -> builtins.int'
[builtins fixtures/classmethod.pyi]

[case testAttrsOtherOverloads]
import attr
from typing import overload, Union

@attr.s
class A:
a = attr.ib()
b = attr.ib(default=3)

@classmethod
def other(cls) -> str:
return "..."

@overload
@classmethod
def foo(cls, x: int) -> int: ...

@overload
@classmethod
def foo(cls, x: str) -> str: ...

@classmethod
def foo(cls, x: Union[int, str]) -> Union[int, str]:
reveal_type(cls) # E: Revealed type is 'def (a: Any, b: Any =) -> __main__.A'
reveal_type(cls.other()) # E: Revealed type is 'builtins.str'
return x

reveal_type(A.foo(3)) # E: Revealed type is 'builtins.int'
reveal_type(A.foo("foo")) # E: Revealed type is 'builtins.str'

[builtins fixtures/classmethod.pyi]

[case testAttrsDefaultDecorator]
import attr
@attr.s
Expand Down
Loading