Skip to content

Prohibit referring to class within its definition #3665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 53 additions & 13 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def parse_type_comment(type_comment: str, line: int, errors: Optional[Errors]) -
raise
else:
assert isinstance(typ, ast3.Expression)
return TypeConverter(errors, line=line).visit(typ.body)

# parse_type_comments() is meant to be used on types within strings or comments, so
# there's no need to check if the class is currently being defined or not. It also
# doesn't matter if we're using stub files or not.
return TypeConverter(errors, set(), line=line).visit(typ.body)


def with_line(f: Callable[['ASTConverter', T], U]) -> Callable[['ASTConverter', T], U]:
Expand Down Expand Up @@ -142,7 +146,7 @@ def __init__(self,
options: Options,
is_stub: bool,
errors: Errors) -> None:
self.class_nesting = 0
self.classes_being_defined = [set()] # type: List[Set[str]]
self.imports = [] # type: List[ImportBase]

self.options = options
Expand All @@ -152,6 +156,14 @@ def __init__(self,
def fail(self, msg: str, line: int, column: int) -> None:
self.errors.report(line, column, msg)

def convert_to_type(self, node: ast3.AST, lineno: int, skip_class_check: bool = False) -> Type:
if skip_class_check or self.is_stub:
classes = set() # type: Set[str]
else:
classes = self.classes_being_defined[-1]

return TypeConverter(self.errors, classes, line=lineno).visit(node)

def generic_visit(self, node: ast3.AST) -> None:
raise RuntimeError('AST node not implemented: ' + str(type(node)))

Expand Down Expand Up @@ -254,7 +266,21 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
return ret

def in_class(self) -> bool:
return self.class_nesting > 0
return len(self.classes_being_defined[-1]) > 0

def enter_function_body(self) -> None:
# When defining a method, the body is not processed until
# after the containing class is fully defined, so we reset
# the set of classes being defined since to record that we
# can refer to our parent class directly, without needing
# forward references.
#
# If this is a regular function, not a method, pushing an
# empty set is a harmless no-op.
self.classes_being_defined.append(set())

def leave_function_body(self) -> None:
self.classes_being_defined.pop()

def translate_module_id(self, id: str) -> str:
"""Return the actual, internal module id for a source text id.
Expand Down Expand Up @@ -326,12 +352,12 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
# PEP 484 disallows both type annotations and type comments
if n.returns or any(a.type_annotation is not None for a in args):
self.fail(messages.DUPLICATE_TYPE_SIGNATURES, n.lineno, n.col_offset)
translated_args = (TypeConverter(self.errors, line=n.lineno)
translated_args = (TypeConverter(self.errors, set(), line=n.lineno)
.translate_expr_list(func_type_ast.argtypes))
arg_types = [a if a is not None else AnyType()
for a in translated_args]
return_type = TypeConverter(self.errors,
line=n.lineno).visit(func_type_ast.returns)
return_type = TypeConverter(
self.errors, set(), line=n.lineno).visit(func_type_ast.returns)

# add implicit self type
if self.in_class() and len(arg_types) < len(args):
Expand All @@ -342,8 +368,9 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
return_type = AnyType()
else:
arg_types = [a.type_annotation for a in args]
return_type = TypeConverter(self.errors, line=n.returns.lineno
if n.returns else n.lineno).visit(n.returns)
return_type = self.convert_to_type(
n.returns,
n.returns.lineno if n.returns else n.lineno)

for arg, arg_type in zip(args, arg_types):
self.set_type_optional(arg_type, arg.initializer)
Expand All @@ -366,10 +393,13 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
AnyType(implicit=True),
None)

self.enter_function_body()
func_def = FuncDef(n.name,
args,
self.as_block(n.body, n.lineno),
func_type)
self.leave_function_body()

if is_coroutine:
# A coroutine is also a generator, mostly for internal reasons.
func_def.is_generator = func_def.is_coroutine = True
Expand Down Expand Up @@ -410,7 +440,7 @@ def make_argument(arg: ast3.arg, default: Optional[ast3.expr], kind: int) -> Arg
self.fail(messages.DUPLICATE_TYPE_SIGNATURES, arg.lineno, arg.col_offset)
arg_type = None
if arg.annotation is not None:
arg_type = TypeConverter(self.errors, line=arg.lineno).visit(arg.annotation)
arg_type = self.convert_to_type(arg.annotation, arg.lineno)
elif arg.type_comment is not None:
arg_type = parse_type_comment(arg.type_comment, arg.lineno, self.errors)
return Argument(Var(arg.arg), arg_type, self.visit(default), kind)
Expand Down Expand Up @@ -460,7 +490,7 @@ def fail_arg(msg: str, arg: ast3.arg) -> None:
# expr* decorator_list)
@with_line
def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
self.class_nesting += 1
self.classes_being_defined[-1].add(n.name)
metaclass_arg = find(lambda x: x.arg == 'metaclass', n.keywords)
metaclass = None
if metaclass_arg:
Expand All @@ -477,7 +507,7 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
metaclass=metaclass,
keywords=keywords)
cdef.decorators = self.translate_expr_list(n.decorator_list)
self.class_nesting -= 1
self.classes_being_defined[-1].remove(n.name)
return cdef

# Return(expr? value)
Expand Down Expand Up @@ -513,7 +543,7 @@ def visit_AnnAssign(self, n: ast3.AnnAssign) -> AssignmentStmt:
rvalue = TempNode(AnyType()) # type: Expression
else:
rvalue = self.visit(n.value)
typ = TypeConverter(self.errors, line=n.lineno).visit(n.annotation)
typ = self.convert_to_type(n.annotation, n.lineno)
typ.column = n.annotation.col_offset
return AssignmentStmt([self.visit(n.target)], rvalue, type=typ, new_syntax=True)

Expand Down Expand Up @@ -961,11 +991,18 @@ def visit_Index(self, n: ast3.Index) -> Node:


class TypeConverter(ast3.NodeTransformer): # type: ignore # typeshed PR #931
def __init__(self, errors: Errors, line: int = -1) -> None:
def __init__(self,
errors: Errors,
classes_being_defined: Set[str],
line: int = -1) -> None:
self.errors = errors
self.classes_being_defined = classes_being_defined
self.line = line
self.node_stack = [] # type: List[ast3.AST]

def _definition_is_incomplete(self, name: str) -> bool:
return name in self.classes_being_defined

def visit(self, node: ast3.AST) -> Type:
"""Modified visit -- keep track of the stack of nodes"""
self.node_stack.append(node)
Expand Down Expand Up @@ -1049,6 +1086,9 @@ def _extract_argument_name(self, n: ast3.expr) -> str:
return None

def visit_Name(self, n: ast3.Name) -> Type:
if self._definition_is_incomplete(n.id):
self.fail("class '{}' is not fully defined; use a forward reference".format(n.id),
n.lineno, n.col_offset)
return UnboundType(n.id, line=self.line)

def visit_NameConstant(self, n: ast3.NameConstant) -> Type:
Expand Down
4 changes: 2 additions & 2 deletions mypy/fastparse2.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def visit_Module(self, mod: ast27.Module) -> MypyFile:
# arg? kwarg, expr* defaults)
@with_line
def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
converter = TypeConverter(self.errors, line=n.lineno)
converter = TypeConverter(self.errors, set(), line=n.lineno)
args, decompose_stmts = self.transform_args(n.args, n.lineno)

arg_kinds = [arg.kind for arg in args]
Expand Down Expand Up @@ -378,7 +378,7 @@ def transform_args(self,
# TODO: remove the cast once https://github.com/python/typeshed/pull/522
# is accepted and synced
type_comments = cast(List[str], n.type_comments) # type: ignore
converter = TypeConverter(self.errors, line=line)
converter = TypeConverter(self.errors, set(), line=line)
decompose_stmts = [] # type: List[Statement]

def extract_names(arg: ast27.expr) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-class-namedtuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ class XRepr(NamedTuple):
y: int = 1
def __str__(self) -> str:
return 'string'
def __add__(self, other: XRepr) -> int:
def __add__(self, other: 'XRepr') -> int:
return 0

reveal_type(XMeth(1).double()) # E: Revealed type is 'builtins.int'
Expand Down
28 changes: 14 additions & 14 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -1477,24 +1477,24 @@ from typing import Any
def deco(f: Any) -> Any: return f
class C:
@deco
def __add__(self, other: C) -> C: return C()
def __radd__(self, other: C) -> C: return C()
def __add__(self, other: 'C') -> 'C': return C()
def __radd__(self, other: 'C') -> 'C': return C()
[out]

[case testReverseOperatorMethodForwardIsAny2]
from typing import Any
def deco(f: Any) -> Any: return f
class C:
__add__ = None # type: Any
def __radd__(self, other: C) -> C: return C()
def __radd__(self, other: 'C') -> 'C': return C()
[out]

[case testReverseOperatorMethodForwardIsAny3]
from typing import Any
def deco(f: Any) -> Any: return f
class C:
__add__ = 42
def __radd__(self, other: C) -> C: return C()
def __radd__(self, other: 'C') -> 'C': return C()
[out]
main:5: error: Forward operator "__add__" is not callable

Expand Down Expand Up @@ -1631,7 +1631,7 @@ main:8: error: Signatures of "__iadd__" and "__add__" are incompatible

a, b = None, None # type: A, B
class A:
def __getattribute__(self, x: str) -> A:
def __getattribute__(self, x: str) -> 'A':
return A()
class B: pass

Expand All @@ -1642,11 +1642,11 @@ main:9: error: Incompatible types in assignment (expression has type "A", variab

[case testGetAttributeSignature]
class A:
def __getattribute__(self, x: str) -> A: pass
def __getattribute__(self, x: str) -> 'A': pass
class B:
def __getattribute__(self, x: A) -> B: pass
def __getattribute__(self, x: A) -> 'B': pass
class C:
def __getattribute__(self, x: str, y: str) -> C: pass
def __getattribute__(self, x: str, y: str) -> 'C': pass
class D:
def __getattribute__(self, x: str) -> None: pass
[out]
Expand All @@ -1657,7 +1657,7 @@ main:6: error: Invalid signature "def (__main__.C, builtins.str, builtins.str) -

a, b = None, None # type: A, B
class A:
def __getattr__(self, x: str) -> A:
def __getattr__(self, x: str) -> 'A':
return A()
class B: pass

Expand All @@ -1668,11 +1668,11 @@ main:9: error: Incompatible types in assignment (expression has type "A", variab

[case testGetAttrSignature]
class A:
def __getattr__(self, x: str) -> A: pass
def __getattr__(self, x: str) -> 'A': pass
class B:
def __getattr__(self, x: A) -> B: pass
def __getattr__(self, x: A) -> 'B': pass
class C:
def __getattr__(self, x: str, y: str) -> C: pass
def __getattr__(self, x: str, y: str) -> 'C': pass
class D:
def __getattr__(self, x: str) -> None: pass
[out]
Expand Down Expand Up @@ -1776,7 +1776,7 @@ a = a(b) # E: Argument 1 to "__call__" of "A" has incompatible type "B"; expect
b = a(a) # E: Incompatible types in assignment (expression has type "A", variable has type "B")

class A:
def __call__(self, x: A) -> A:
def __call__(self, x: 'A') -> 'A':
pass
class B: pass

Expand Down Expand Up @@ -3280,7 +3280,7 @@ def r(ta: Type[TA], tta: TTA) -> None:

class Class(metaclass=M):
@classmethod
def f1(cls: Type[Class]) -> None: pass
def f1(cls: Type['Class']) -> None: pass
@classmethod
def f2(cls: M) -> None: pass
cl: Type[Class] = m # E: Incompatible types in assignment (expression has type "M", variable has type Type[Class])
Expand Down
16 changes: 8 additions & 8 deletions test-data/unit/check-selftype.test
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class A:
pass

class C(A):
def copy(self: C) -> C:
def copy(self: 'C') -> 'C':
pass

class D(A):
Expand Down Expand Up @@ -276,10 +276,10 @@ class B:
return cls()

class C:
def foo(self: C) -> C: return self
def foo(self: 'C') -> 'C': return self

@classmethod
def cfoo(cls: Type[C]) -> C:
def cfoo(cls: Type['C']) -> 'C':
return cls()

class D:
Expand Down Expand Up @@ -330,21 +330,21 @@ class B:
pass

class C:
def __new__(cls: Type[C]) -> C:
def __new__(cls: Type['C']) -> 'C':
return cls()

def __init_subclass__(cls: Type[C]) -> None:
def __init_subclass__(cls: Type['C']) -> None:
pass

class D:
def __new__(cls: D) -> D: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
def __new__(cls: 'D') -> 'D': # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
return cls

def __init_subclass__(cls: D) -> None: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
def __init_subclass__(cls: 'D') -> None: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
pass

class E:
def __new__(cls) -> E:
def __new__(cls) -> 'E':
reveal_type(cls) # E: Revealed type is 'def () -> __main__.E'
return cls()

Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class A(object):
self.a = 0

def __iadd__(self, a):
# type: (int) -> A
# type: (int) -> 'A'
self.a += 1
return self

Expand Down
18 changes: 9 additions & 9 deletions test-data/unit/check-typevar-values.test
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ f(S())
[case testCheckGenericFunctionBodyWithTypeVarValues]
from typing import TypeVar
class A:
def f(self, x: int) -> A: return self
def f(self, x: int) -> 'A': return self
class B:
def f(self, x: int) -> B: return self
def f(self, x: int) -> 'B': return self
AB = TypeVar('AB', A, B)
def f(x: AB) -> AB:
x = x.f(1)
Expand All @@ -58,11 +58,11 @@ def f(x: AB) -> AB:
[case testCheckGenericFunctionBodyWithTypeVarValues2]
from typing import TypeVar
class A:
def f(self) -> A: return A()
def g(self) -> B: return B()
def f(self) -> 'A': return A()
def g(self) -> 'B': return B()
class B:
def f(self) -> A: return A()
def g(self) -> B: return B()
def g(self) -> 'B': return B()
AB = TypeVar('AB', A, B)
def f(x: AB) -> AB:
return x.f() # Error
Expand All @@ -75,11 +75,11 @@ main:12: error: Incompatible return value type (got "B", expected "A")
[case testTypeInferenceAndTypeVarValues]
from typing import TypeVar
class A:
def f(self) -> A: return self
def g(self) -> B: return B()
def f(self) -> 'A': return self
def g(self) -> 'B': return B()
class B:
def f(self) -> B: return self
def g(self) -> B: return B()
def f(self) -> 'B': return self
def g(self) -> 'B': return B()
AB = TypeVar('AB', A, B)
def f(x: AB) -> AB:
y = x
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class A:
def g(self) -> None: pass
[file m.py.2]
class A:
def g(self, a: A) -> None: pass
def g(self, a: 'A') -> None: pass
[out]
==
main:4: error: Too few arguments for "g" of "A"
Expand Down
Loading