diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 58835c6de810..298b5327e73e 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -109,7 +109,11 @@ def parse_type_comment(type_comment: str, line: int, errors: Optional[Errors]) - raise else: assert isinstance(typ, ast3.Expression) - return TypeConverter(errors, line=line).visit(typ.body) + + # parse_type_comments() is meant to be used on types within strings or comments, so + # there's no need to check if the class is currently being defined or not. It also + # doesn't matter if we're using stub files or not. + return TypeConverter(errors, set(), line=line).visit(typ.body) def with_line(f: Callable[['ASTConverter', T], U]) -> Callable[['ASTConverter', T], U]: @@ -142,7 +146,7 @@ def __init__(self, options: Options, is_stub: bool, errors: Errors) -> None: - self.class_nesting = 0 + self.classes_being_defined = [set()] # type: List[Set[str]] self.imports = [] # type: List[ImportBase] self.options = options @@ -152,6 +156,14 @@ def __init__(self, def fail(self, msg: str, line: int, column: int) -> None: self.errors.report(line, column, msg) + def convert_to_type(self, node: ast3.AST, lineno: int, skip_class_check: bool = False) -> Type: + if skip_class_check or self.is_stub: + classes = set() # type: Set[str] + else: + classes = self.classes_being_defined[-1] + + return TypeConverter(self.errors, classes, line=lineno).visit(node) + def generic_visit(self, node: ast3.AST) -> None: raise RuntimeError('AST node not implemented: ' + str(type(node))) @@ -254,7 +266,21 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: return ret def in_class(self) -> bool: - return self.class_nesting > 0 + return len(self.classes_being_defined[-1]) > 0 + + def enter_function_body(self) -> None: + # When defining a method, the body is not processed until + # after the containing class is fully defined, so we reset + # the set of classes being defined since to record that we + # can refer to our parent class directly, without needing + # forward references. + # + # If this is a regular function, not a method, pushing an + # empty set is a harmless no-op. + self.classes_being_defined.append(set()) + + def leave_function_body(self) -> None: + self.classes_being_defined.pop() def translate_module_id(self, id: str) -> str: """Return the actual, internal module id for a source text id. @@ -326,12 +352,12 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef], # PEP 484 disallows both type annotations and type comments if n.returns or any(a.type_annotation is not None for a in args): self.fail(messages.DUPLICATE_TYPE_SIGNATURES, n.lineno, n.col_offset) - translated_args = (TypeConverter(self.errors, line=n.lineno) + translated_args = (TypeConverter(self.errors, set(), line=n.lineno) .translate_expr_list(func_type_ast.argtypes)) arg_types = [a if a is not None else AnyType() for a in translated_args] - return_type = TypeConverter(self.errors, - line=n.lineno).visit(func_type_ast.returns) + return_type = TypeConverter( + self.errors, set(), line=n.lineno).visit(func_type_ast.returns) # add implicit self type if self.in_class() and len(arg_types) < len(args): @@ -342,8 +368,9 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef], return_type = AnyType() else: arg_types = [a.type_annotation for a in args] - return_type = TypeConverter(self.errors, line=n.returns.lineno - if n.returns else n.lineno).visit(n.returns) + return_type = self.convert_to_type( + n.returns, + n.returns.lineno if n.returns else n.lineno) for arg, arg_type in zip(args, arg_types): self.set_type_optional(arg_type, arg.initializer) @@ -366,10 +393,13 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef], AnyType(implicit=True), None) + self.enter_function_body() func_def = FuncDef(n.name, args, self.as_block(n.body, n.lineno), func_type) + self.leave_function_body() + if is_coroutine: # A coroutine is also a generator, mostly for internal reasons. func_def.is_generator = func_def.is_coroutine = True @@ -410,7 +440,7 @@ def make_argument(arg: ast3.arg, default: Optional[ast3.expr], kind: int) -> Arg self.fail(messages.DUPLICATE_TYPE_SIGNATURES, arg.lineno, arg.col_offset) arg_type = None if arg.annotation is not None: - arg_type = TypeConverter(self.errors, line=arg.lineno).visit(arg.annotation) + arg_type = self.convert_to_type(arg.annotation, arg.lineno) elif arg.type_comment is not None: arg_type = parse_type_comment(arg.type_comment, arg.lineno, self.errors) return Argument(Var(arg.arg), arg_type, self.visit(default), kind) @@ -460,7 +490,7 @@ def fail_arg(msg: str, arg: ast3.arg) -> None: # expr* decorator_list) @with_line def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef: - self.class_nesting += 1 + self.classes_being_defined[-1].add(n.name) metaclass_arg = find(lambda x: x.arg == 'metaclass', n.keywords) metaclass = None if metaclass_arg: @@ -477,7 +507,7 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef: metaclass=metaclass, keywords=keywords) cdef.decorators = self.translate_expr_list(n.decorator_list) - self.class_nesting -= 1 + self.classes_being_defined[-1].remove(n.name) return cdef # Return(expr? value) @@ -513,7 +543,7 @@ def visit_AnnAssign(self, n: ast3.AnnAssign) -> AssignmentStmt: rvalue = TempNode(AnyType()) # type: Expression else: rvalue = self.visit(n.value) - typ = TypeConverter(self.errors, line=n.lineno).visit(n.annotation) + typ = self.convert_to_type(n.annotation, n.lineno) typ.column = n.annotation.col_offset return AssignmentStmt([self.visit(n.target)], rvalue, type=typ, new_syntax=True) @@ -961,11 +991,18 @@ def visit_Index(self, n: ast3.Index) -> Node: class TypeConverter(ast3.NodeTransformer): # type: ignore # typeshed PR #931 - def __init__(self, errors: Errors, line: int = -1) -> None: + def __init__(self, + errors: Errors, + classes_being_defined: Set[str], + line: int = -1) -> None: self.errors = errors + self.classes_being_defined = classes_being_defined self.line = line self.node_stack = [] # type: List[ast3.AST] + def _definition_is_incomplete(self, name: str) -> bool: + return name in self.classes_being_defined + def visit(self, node: ast3.AST) -> Type: """Modified visit -- keep track of the stack of nodes""" self.node_stack.append(node) @@ -1049,6 +1086,9 @@ def _extract_argument_name(self, n: ast3.expr) -> str: return None def visit_Name(self, n: ast3.Name) -> Type: + if self._definition_is_incomplete(n.id): + self.fail("class '{}' is not fully defined; use a forward reference".format(n.id), + n.lineno, n.col_offset) return UnboundType(n.id, line=self.line) def visit_NameConstant(self, n: ast3.NameConstant) -> Type: diff --git a/mypy/fastparse2.py b/mypy/fastparse2.py index 109dfe407cf2..cc90e85486a3 100644 --- a/mypy/fastparse2.py +++ b/mypy/fastparse2.py @@ -280,7 +280,7 @@ def visit_Module(self, mod: ast27.Module) -> MypyFile: # arg? kwarg, expr* defaults) @with_line def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement: - converter = TypeConverter(self.errors, line=n.lineno) + converter = TypeConverter(self.errors, set(), line=n.lineno) args, decompose_stmts = self.transform_args(n.args, n.lineno) arg_kinds = [arg.kind for arg in args] @@ -378,7 +378,7 @@ def transform_args(self, # TODO: remove the cast once https://github.com/python/typeshed/pull/522 # is accepted and synced type_comments = cast(List[str], n.type_comments) # type: ignore - converter = TypeConverter(self.errors, line=line) + converter = TypeConverter(self.errors, set(), line=line) decompose_stmts = [] # type: List[Statement] def extract_names(arg: ast27.expr) -> List[str]: diff --git a/test-data/unit/check-class-namedtuple.test b/test-data/unit/check-class-namedtuple.test index 710f750369ff..c9f356d28cf8 100644 --- a/test-data/unit/check-class-namedtuple.test +++ b/test-data/unit/check-class-namedtuple.test @@ -499,7 +499,7 @@ class XRepr(NamedTuple): y: int = 1 def __str__(self) -> str: return 'string' - def __add__(self, other: XRepr) -> int: + def __add__(self, other: 'XRepr') -> int: return 0 reveal_type(XMeth(1).double()) # E: Revealed type is 'builtins.int' diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index c7e6f52070c7..c1de1a0358da 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -1477,8 +1477,8 @@ from typing import Any def deco(f: Any) -> Any: return f class C: @deco - def __add__(self, other: C) -> C: return C() - def __radd__(self, other: C) -> C: return C() + def __add__(self, other: 'C') -> 'C': return C() + def __radd__(self, other: 'C') -> 'C': return C() [out] [case testReverseOperatorMethodForwardIsAny2] @@ -1486,7 +1486,7 @@ from typing import Any def deco(f: Any) -> Any: return f class C: __add__ = None # type: Any - def __radd__(self, other: C) -> C: return C() + def __radd__(self, other: 'C') -> 'C': return C() [out] [case testReverseOperatorMethodForwardIsAny3] @@ -1494,7 +1494,7 @@ from typing import Any def deco(f: Any) -> Any: return f class C: __add__ = 42 - def __radd__(self, other: C) -> C: return C() + def __radd__(self, other: 'C') -> 'C': return C() [out] main:5: error: Forward operator "__add__" is not callable @@ -1631,7 +1631,7 @@ main:8: error: Signatures of "__iadd__" and "__add__" are incompatible a, b = None, None # type: A, B class A: - def __getattribute__(self, x: str) -> A: + def __getattribute__(self, x: str) -> 'A': return A() class B: pass @@ -1642,11 +1642,11 @@ main:9: error: Incompatible types in assignment (expression has type "A", variab [case testGetAttributeSignature] class A: - def __getattribute__(self, x: str) -> A: pass + def __getattribute__(self, x: str) -> 'A': pass class B: - def __getattribute__(self, x: A) -> B: pass + def __getattribute__(self, x: A) -> 'B': pass class C: - def __getattribute__(self, x: str, y: str) -> C: pass + def __getattribute__(self, x: str, y: str) -> 'C': pass class D: def __getattribute__(self, x: str) -> None: pass [out] @@ -1657,7 +1657,7 @@ main:6: error: Invalid signature "def (__main__.C, builtins.str, builtins.str) - a, b = None, None # type: A, B class A: - def __getattr__(self, x: str) -> A: + def __getattr__(self, x: str) -> 'A': return A() class B: pass @@ -1668,11 +1668,11 @@ main:9: error: Incompatible types in assignment (expression has type "A", variab [case testGetAttrSignature] class A: - def __getattr__(self, x: str) -> A: pass + def __getattr__(self, x: str) -> 'A': pass class B: - def __getattr__(self, x: A) -> B: pass + def __getattr__(self, x: A) -> 'B': pass class C: - def __getattr__(self, x: str, y: str) -> C: pass + def __getattr__(self, x: str, y: str) -> 'C': pass class D: def __getattr__(self, x: str) -> None: pass [out] @@ -1776,7 +1776,7 @@ a = a(b) # E: Argument 1 to "__call__" of "A" has incompatible type "B"; expect b = a(a) # E: Incompatible types in assignment (expression has type "A", variable has type "B") class A: - def __call__(self, x: A) -> A: + def __call__(self, x: 'A') -> 'A': pass class B: pass @@ -3280,7 +3280,7 @@ def r(ta: Type[TA], tta: TTA) -> None: class Class(metaclass=M): @classmethod - def f1(cls: Type[Class]) -> None: pass + def f1(cls: Type['Class']) -> None: pass @classmethod def f2(cls: M) -> None: pass cl: Type[Class] = m # E: Incompatible types in assignment (expression has type "M", variable has type Type[Class]) diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test index d0c4a56f2038..e7a4898835ac 100644 --- a/test-data/unit/check-selftype.test +++ b/test-data/unit/check-selftype.test @@ -181,7 +181,7 @@ class A: pass class C(A): - def copy(self: C) -> C: + def copy(self: 'C') -> 'C': pass class D(A): @@ -276,10 +276,10 @@ class B: return cls() class C: - def foo(self: C) -> C: return self + def foo(self: 'C') -> 'C': return self @classmethod - def cfoo(cls: Type[C]) -> C: + def cfoo(cls: Type['C']) -> 'C': return cls() class D: @@ -330,21 +330,21 @@ class B: pass class C: - def __new__(cls: Type[C]) -> C: + def __new__(cls: Type['C']) -> 'C': return cls() - def __init_subclass__(cls: Type[C]) -> None: + def __init_subclass__(cls: Type['C']) -> None: pass class D: - def __new__(cls: D) -> D: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]' + def __new__(cls: 'D') -> 'D': # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]' return cls - def __init_subclass__(cls: D) -> None: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]' + def __init_subclass__(cls: 'D') -> None: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]' pass class E: - def __new__(cls) -> E: + def __new__(cls) -> 'E': reveal_type(cls) # E: Revealed type is 'def () -> __main__.E' return cls() diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 8c1f85b1d743..0c39cee2f3b9 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -357,7 +357,7 @@ class A(object): self.a = 0 def __iadd__(self, a): - # type: (int) -> A + # type: (int) -> 'A' self.a += 1 return self diff --git a/test-data/unit/check-typevar-values.test b/test-data/unit/check-typevar-values.test index 36df2235a209..54645794ba5f 100644 --- a/test-data/unit/check-typevar-values.test +++ b/test-data/unit/check-typevar-values.test @@ -47,9 +47,9 @@ f(S()) [case testCheckGenericFunctionBodyWithTypeVarValues] from typing import TypeVar class A: - def f(self, x: int) -> A: return self + def f(self, x: int) -> 'A': return self class B: - def f(self, x: int) -> B: return self + def f(self, x: int) -> 'B': return self AB = TypeVar('AB', A, B) def f(x: AB) -> AB: x = x.f(1) @@ -58,11 +58,11 @@ def f(x: AB) -> AB: [case testCheckGenericFunctionBodyWithTypeVarValues2] from typing import TypeVar class A: - def f(self) -> A: return A() - def g(self) -> B: return B() + def f(self) -> 'A': return A() + def g(self) -> 'B': return B() class B: def f(self) -> A: return A() - def g(self) -> B: return B() + def g(self) -> 'B': return B() AB = TypeVar('AB', A, B) def f(x: AB) -> AB: return x.f() # Error @@ -75,11 +75,11 @@ main:12: error: Incompatible return value type (got "B", expected "A") [case testTypeInferenceAndTypeVarValues] from typing import TypeVar class A: - def f(self) -> A: return self - def g(self) -> B: return B() + def f(self) -> 'A': return self + def g(self) -> 'B': return B() class B: - def f(self) -> B: return self - def g(self) -> B: return B() + def f(self) -> 'B': return self + def g(self) -> 'B': return B() AB = TypeVar('AB', A, B) def f(x: AB) -> AB: y = x diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index 17d126f1e108..8b96f0d7d320 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -52,7 +52,7 @@ class A: def g(self) -> None: pass [file m.py.2] class A: - def g(self, a: A) -> None: pass + def g(self, a: 'A') -> None: pass [out] == main:4: error: Too few arguments for "g" of "A" diff --git a/test-data/unit/merge.test b/test-data/unit/merge.test index a6d2a424f975..8d703382eedf 100644 --- a/test-data/unit/merge.test +++ b/test-data/unit/merge.test @@ -442,12 +442,12 @@ NameExpr:6: target.A<0> import target [file target.py] class A: - def f(self) -> A: + def f(self) -> 'A': return self.f() [file target.py.next] class A: # Extra line to change line numbers - def f(self) -> A: + def f(self) -> 'A': return self.f() [out] ## target diff --git a/test-data/unit/semanal-errors.test b/test-data/unit/semanal-errors.test index 2192bce3221d..6d1cf8688247 100644 --- a/test-data/unit/semanal-errors.test +++ b/test-data/unit/semanal-errors.test @@ -1401,3 +1401,40 @@ class A: ... # E: Name 'A' already defined on line 2 [builtins fixtures/list.pyi] [out] + +[case testReferringToClassInDefinition] +class A: + def foo(self) -> A: # E: class 'A' is not fully defined; use a forward reference + pass + +[out] + +[case testReferringToClassInClassVar] +class A: + x: A # E: class 'A' is not fully defined; use a forward reference + +[out] + +[case testReferringToClassInNestedStructures] +class A: + class B: + def foo(self, arg: A) -> int: # E: class 'A' is not fully defined; use a forward reference + return 3 + + def bar(self, arg: B) -> int: # E: class 'B' is not fully defined; use a forward reference + return 4 + + def baz(self) -> int: + x: A + + class C: + def qux(self, arg: A) -> int: + return 3 + + def qix(self, arg: C) -> int: # E: class 'C' is not fully defined; use a forward reference + return 3 + + return 3 + +[out] + diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index 41359265e64b..f44e03738422 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -108,7 +108,7 @@ a = 1 + 2 1.2 * 3 2.2 - 3 1 / 2 -[file builtins.py] +[file builtins.pyi] class object: def __init__(self) -> None: pass class function: pass @@ -135,7 +135,7 @@ import typing 1 < 2 < 3 8 > 3 4 < 6 > 2 -[file builtins.py] +[file builtins.pyi] class object: def __init__(self) -> None: pass class int: