diff --git a/mypy/checker.py b/mypy/checker.py index bb25f5b95d25..397cc12f7a6a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2,9 +2,10 @@ import itertools import fnmatch +from contextlib import contextmanager from typing import ( - Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple + Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator ) from mypy.errors import Errors, report_internal_error @@ -35,10 +36,11 @@ from mypy.sametypes import is_same_type from mypy.messages import MessageBuilder import mypy.checkexpr -from mypy.checkmember import map_type_from_supertype, bind_self +from mypy.checkmember import map_type_from_supertype, bind_self, erase_to_bound from mypy import messages from mypy.subtypes import ( - is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away + is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away, + is_subtype_ignoring_tvars ) from mypy.maptype import map_instance_to_supertype from mypy.semanal import fill_typevars, set_callable_name, refers_to_fullname @@ -65,7 +67,7 @@ [ ('node', FuncItem), ('context_type_name', Optional[str]), # Name of the surrounding class (for error messages) - ('class_type', Optional[Type]), # And its type (from class_context) + ('active_class', Optional[Type]), # And its type (for selftype handline) ]) @@ -91,19 +93,13 @@ class TypeChecker(NodeVisitor[Type]): # Helper for type checking expressions expr_checker = None # type: mypy.checkexpr.ExpressionChecker - # Class context for checking overriding of a method of the form - # def foo(self: T) -> T - # We need to pass the current class definition for instantiation of T - class_context = None # type: List[Type] - + scope = None # type: Scope # Stack of function return types return_types = None # type: List[Type] # Type context for type inference type_context = None # type: List[Type] # Flags; true for dynamically typed functions dynamic_funcs = None # type: List[bool] - # Stack of functions being type checked - function_stack = None # type: List[FuncItem] # Stack of collections of variables with partial types partial_types = None # type: List[Dict[Var, Context]] globals = None # type: SymbolTable @@ -139,13 +135,12 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option self.path = path self.msg = MessageBuilder(errors, modules) self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg) - self.class_context = [] + self.scope = Scope(tree) self.binder = ConditionalTypeBinder() self.globals = tree.names self.return_types = [] self.type_context = [] self.dynamic_funcs = [] - self.function_stack = [] self.partial_types = [] self.deferred_nodes = [] self.type_map = {} @@ -203,7 +198,7 @@ def check_second_pass(self) -> bool: todo = self.deferred_nodes self.deferred_nodes = [] done = set() # type: Set[FuncItem] - for node, type_name, class_type in todo: + for node, type_name, active_class in todo: if node in done: continue # This is useful for debugging: @@ -212,28 +207,27 @@ def check_second_pass(self) -> bool: done.add(node) if type_name: self.errors.push_type(type_name) - if class_type: - self.class_context.append(class_type) - self.accept(node) - if class_type: - self.class_context.pop() + + if active_class: + with self.scope.push_class(active_class): + self.accept(node) + else: + self.accept(node) if type_name: self.errors.pop_type() return True def handle_cannot_determine_type(self, name: str, context: Context) -> None: - if self.pass_num < LAST_PASS and self.function_stack: + node = self.scope.top_function() + if self.pass_num < LAST_PASS and node is not None: # Don't report an error yet. Just defer. - node = self.function_stack[-1] if self.errors.type_name: type_name = self.errors.type_name[-1] else: type_name = None - if self.class_context: - class_context_top = self.class_context[-1] - else: - class_context_top = None - self.deferred_nodes.append(DeferredNode(node, type_name, class_context_top)) + # Shouldn't we freeze the entire scope? + active_class = self.scope.active_class() + self.deferred_nodes.append(DeferredNode(node, type_name, active_class)) # Set a marker so that we won't infer additional types in this # function. Any inferred types could be bogus, because there's at # least one type that we don't know. @@ -508,7 +502,6 @@ def check_func_item(self, defn: FuncItem, if isinstance(defn, FuncDef): fdef = defn - self.function_stack.append(defn) self.dynamic_funcs.append(defn.is_dynamic() and not type_override) if fdef: @@ -530,7 +523,6 @@ def check_func_item(self, defn: FuncItem, self.errors.pop_function() self.dynamic_funcs.pop() - self.function_stack.pop() self.current_node_deferred = False def check_func_def(self, defn: FuncItem, typ: CallableType, name: str) -> None: @@ -616,14 +608,22 @@ def is_implicit_any(t: Type) -> bool: for i in range(len(typ.arg_types)): arg_type = typ.arg_types[i] - # Refuse covariant parameter type variables - # TODO: check recuresively for inner type variables - if isinstance(arg_type, TypeVarType): - if i > 0: - if arg_type.variance == COVARIANT: - self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, - arg_type) - # FIX: if i == 0 and this is not a method then same as above + ref_type = self.scope.active_class() + if (isinstance(defn, FuncDef) and ref_type is not None and i == 0 + and not defn.is_static + and typ.arg_kinds[0] not in [nodes.ARG_STAR, nodes.ARG_STAR2]): + if defn.is_class or defn.name() == '__new__': + ref_type = mypy.types.TypeType(ref_type) + erased = erase_to_bound(arg_type) + if not is_subtype_ignoring_tvars(ref_type, erased): + self.fail("The erased type of self '{}' " + "is not a supertype of its class '{}'" + .format(erased, ref_type), defn) + elif isinstance(arg_type, TypeVarType): + # Refuse covariant parameter type variables + # TODO: check recuresively for inner type variables + if arg_type.variance == COVARIANT: + self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, arg_type) if typ.arg_kinds[i] == nodes.ARG_STAR: # builtins.tuple[T] is typing.Tuple[T, ...] arg_type = self.named_generic_type('builtins.tuple', @@ -642,7 +642,8 @@ def is_implicit_any(t: Type) -> bool: # Type check body in a new scope. with self.binder.top_frame_context(): - self.accept(item.body) + with self.scope.push_function(defn): + self.accept(item.body) unreachable = self.binder.is_unreachable() if (self.options.warn_no_return and not unreachable @@ -888,7 +889,7 @@ def check_method_override_for_base_with_name( # The name of the method is defined in the base class. # Construct the type of the overriding method. - typ = bind_self(self.function_type(defn), self.class_context[-1]) + typ = bind_self(self.function_type(defn), self.scope.active_class()) # Map the overridden method type to subtype context so that # it can be checked for compatibility. original_type = base_attr.type @@ -901,7 +902,7 @@ def check_method_override_for_base_with_name( assert False, str(base_attr.node) if isinstance(original_type, FunctionLike): original = map_type_from_supertype( - bind_self(original_type, self.class_context[-1]), + bind_self(original_type, self.scope.active_class()), defn.info, base) # Check that the types are compatible. # TODO overloaded signatures @@ -985,9 +986,8 @@ def visit_class_def(self, defn: ClassDef) -> Type: old_binder = self.binder self.binder = ConditionalTypeBinder() with self.binder.top_frame_context(): - self.class_context.append(fill_typevars(defn.info)) - self.accept(defn.defs) - self.class_context.pop() + with self.scope.push_class(fill_typevars(defn.info)): + self.accept(defn.defs) self.binder = old_binder if not defn.has_incompatible_baseclass: # Otherwise we've already found errors; more errors are not useful @@ -1519,8 +1519,8 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: self.binder.unreachable() def check_return_stmt(self, s: ReturnStmt) -> None: - if self.is_within_function(): - defn = self.function_stack[-1] + defn = self.scope.top_function() + if defn is not None: if defn.is_generator: return_type = self.get_generator_return_type(self.return_types[-1], defn.is_coroutine) @@ -1537,7 +1537,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if self.is_unusable_type(return_type): # Lambdas are allowed to have a unusable returns. # Functions returning a value of type None are allowed to have a Void return. - if isinstance(self.function_stack[-1], FuncExpr) or isinstance(typ, NoneTyp): + if isinstance(self.scope.top_function(), FuncExpr) or isinstance(typ, NoneTyp): return self.fail(messages.NO_RETURN_VALUE_EXPECTED, s) else: @@ -1550,7 +1550,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None: msg=messages.INCOMPATIBLE_RETURN_VALUE_TYPE) else: # Empty returns are valid in Generators with Any typed returns. - if (self.function_stack[-1].is_generator and isinstance(return_type, AnyType)): + if (defn.is_generator and isinstance(return_type, AnyType)): return if isinstance(return_type, (Void, NoneTyp, AnyType)): @@ -2318,13 +2318,6 @@ def find_partial_types(self, var: Var) -> Optional[Dict[Var, Context]]: return partial_types return None - def is_within_function(self) -> bool: - """Are we currently type checking within a function? - - I.e. not at class body or at the top level. - """ - return self.return_types != [] - def is_unusable_type(self, typ: Type): """Is this type an unusable type? @@ -2756,3 +2749,34 @@ def is_valid_inferred_type_component(typ: Type) -> bool: if not is_valid_inferred_type_component(item): return False return True + + +class Scope: + # We keep two stacks combined, to maintain the relative order + stack = None # type: List[Union[Type, FuncItem, MypyFile]] + + def __init__(self, module: MypyFile) -> None: + self.stack = [module] + + def top_function(self) -> Optional[FuncItem]: + for e in reversed(self.stack): + if isinstance(e, FuncItem): + return e + return None + + def active_class(self) -> Optional[Type]: + if isinstance(self.stack[-1], Type): + return self.stack[-1] + return None + + @contextmanager + def push_function(self, item: FuncItem) -> Iterator[None]: + self.stack.append(item) + yield + self.stack.pop() + + @contextmanager + def push_class(self, t: Type) -> Iterator[None]: + self.stack.append(t) + yield + self.stack.pop() diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3c5b0d73cad6..d5a276077cb2 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1604,7 +1604,7 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type: return AnyType() if not self.chk.in_checked_function(): return AnyType() - args = self.chk.function_stack[-1].arguments + args = self.chk.scope.top_function().arguments # An empty args with super() is an error; we need something in declared_self if not args: self.chk.fail('super() requires at least one positional argument', e) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 415dc9770904..b57cb57ba016 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -614,4 +614,4 @@ def erase_to_bound(t: Type): if isinstance(t, TypeType): if isinstance(t.item, TypeVarType): return TypeType(t.item.upper_bound) - assert not t + return t diff --git a/mypy/semanal.py b/mypy/semanal.py index 999c847cceb9..d4b0df66709c 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -329,7 +329,7 @@ def prepare_method_signature(self, func: FuncDef) -> None: elif isinstance(functype, CallableType): self_type = functype.arg_types[0] if isinstance(self_type, AnyType): - if func.is_class: + if func.is_class or func.name() == '__new__': leading_type = self.class_type(self.type) else: leading_type = fill_typevars(self.type) diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test index 77d06843dd60..f97eac4de41c 100644 --- a/test-data/unit/check-selftype.test +++ b/test-data/unit/check-selftype.test @@ -248,3 +248,88 @@ class A: class B(A): def __init__(self, arg: T) -> None: super(B, self).__init__() + +[case testSelfTypeNonsensical] +# flags: --hide-error-context +from typing import TypeVar, Type + +T = TypeVar('T', bound=str) +class A: + def foo(self: T) -> T: # E: The erased type of self 'builtins.str' is not a supertype of its class '__main__.A' + return self + + @classmethod + def cfoo(cls: Type[T]) -> T: # E: The erased type of self 'Type[builtins.str]' is not a supertype of its class 'Type[__main__.A]' + return cls() + +Q = TypeVar('Q', bound='B') +class B: + def foo(self: Q) -> Q: + return self + + @classmethod + def cfoo(cls: Type[Q]) -> Q: + return cls() + +class C: + def foo(self: C) -> C: return self + + @classmethod + def cfoo(cls: Type[C]) -> C: + return cls() + +class D: + def foo(self: str) -> str: # E: The erased type of self 'builtins.str' is not a supertype of its class '__main__.D' + return self + + @staticmethod + def bar(self: str) -> str: + return self + + @classmethod + def cfoo(cls: Type[str]) -> str: # E: The erased type of self 'Type[builtins.str]' is not a supertype of its class 'Type[__main__.D]' + return cls() + +[builtins fixtures/classmethod.pyi] + +[case testSelfTypeLambdaDefault] +# flags: --hide-error-context +from typing import Callable +class C: + @classmethod + def foo(cls, + arg: Callable[[int], str] = lambda a: '' + ) -> None: + pass + + def bar(self, + arg: Callable[[int], str] = lambda a: '' + ) -> None: + pass +[builtins fixtures/classmethod.pyi] + +[case testSelfTypeNew] +# flags: --hide-error-context +from typing import TypeVar, Type + +T = TypeVar('T', bound=A) +class A: + def __new__(cls: Type[T]) -> T: + return cls() + +class B: + def __new__(cls: Type[T]) -> T: # E: The erased type of self 'Type[__main__.A]' is not a supertype of its class 'Type[__main__.B]' + return cls() + +class C: + def __new__(cls: Type[C]) -> C: + return cls() + +class D: + def __new__(cls: D) -> D: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]' + return cls + +class E: + def __new__(cls) -> E: + reveal_type(cls) # E: Revealed type is 'def () -> __main__.E' + return cls() diff --git a/test-data/unit/check-super.test b/test-data/unit/check-super.test index 39a1c63a87ee..c6a4eba2908a 100644 --- a/test-data/unit/check-super.test +++ b/test-data/unit/check-super.test @@ -84,22 +84,19 @@ class C(A, B): main: note: In member "f" of class "C": [case testSuperWithNew] +# flags: --hide-error-context class A: def __new__(cls, x: int) -> 'A': return object.__new__(cls) class B(A): - def __new__(cls, x: str) -> 'A': + def __new__(cls, x: int, y: str = '') -> 'A': super().__new__(cls, 1) - super().__new__(cls, x) # E -B('') -B(1) # E + super().__new__(cls, 1, '') # E: Too many arguments for "__new__" of "A" +B('') # E: Argument 1 to "B" has incompatible type "str"; expected "int" +B(1) +B(1, 'x') [builtins fixtures/__new__.pyi] -[out] -main: note: In member "__new__" of class "B": -main:8: error: Argument 2 to "__new__" of "A" has incompatible type "str"; expected "int" -main: note: At top level: -main:10: error: Argument 1 to "B" has incompatible type "int"; expected "str" [case testSuperWithUnknownBase] from typing import Any