Skip to content

Sanity checks for declared selftype #2381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 2, 2016
130 changes: 77 additions & 53 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import itertools
import fnmatch
from contextlib import contextmanager

from typing import (
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator
)

from mypy.errors import Errors, report_internal_error
Expand Down Expand Up @@ -35,10 +36,11 @@
from mypy.sametypes import is_same_type
from mypy.messages import MessageBuilder
import mypy.checkexpr
from mypy.checkmember import map_type_from_supertype, bind_self
from mypy.checkmember import map_type_from_supertype, bind_self, erase_to_bound
from mypy import messages
from mypy.subtypes import (
is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away
is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away,
is_subtype_ignoring_tvars
)
from mypy.maptype import map_instance_to_supertype
from mypy.semanal import fill_typevars, set_callable_name, refers_to_fullname
Expand All @@ -65,7 +67,7 @@
[
('node', FuncItem),
('context_type_name', Optional[str]), # Name of the surrounding class (for error messages)
('class_type', Optional[Type]), # And its type (from class_context)
('active_class', Optional[Type]), # And its type (for selftype handline)
])


Expand All @@ -91,19 +93,13 @@ class TypeChecker(NodeVisitor[Type]):
# Helper for type checking expressions
expr_checker = None # type: mypy.checkexpr.ExpressionChecker

# Class context for checking overriding of a method of the form
# def foo(self: T) -> T
# We need to pass the current class definition for instantiation of T
class_context = None # type: List[Type]

scope = None # type: Scope
# Stack of function return types
return_types = None # type: List[Type]
# Type context for type inference
type_context = None # type: List[Type]
# Flags; true for dynamically typed functions
dynamic_funcs = None # type: List[bool]
# Stack of functions being type checked
function_stack = None # type: List[FuncItem]
# Stack of collections of variables with partial types
partial_types = None # type: List[Dict[Var, Context]]
globals = None # type: SymbolTable
Expand Down Expand Up @@ -139,13 +135,12 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
self.path = path
self.msg = MessageBuilder(errors, modules)
self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg)
self.class_context = []
self.scope = Scope(tree)
self.binder = ConditionalTypeBinder()
self.globals = tree.names
self.return_types = []
self.type_context = []
self.dynamic_funcs = []
self.function_stack = []
self.partial_types = []
self.deferred_nodes = []
self.type_map = {}
Expand Down Expand Up @@ -203,7 +198,7 @@ def check_second_pass(self) -> bool:
todo = self.deferred_nodes
self.deferred_nodes = []
done = set() # type: Set[FuncItem]
for node, type_name, class_type in todo:
for node, type_name, active_class in todo:
if node in done:
continue
# This is useful for debugging:
Expand All @@ -212,28 +207,27 @@ def check_second_pass(self) -> bool:
done.add(node)
if type_name:
self.errors.push_type(type_name)
if class_type:
self.class_context.append(class_type)
self.accept(node)
if class_type:
self.class_context.pop()

if active_class:
with self.scope.push_class(active_class):
self.accept(node)
else:
self.accept(node)
if type_name:
self.errors.pop_type()
return True

def handle_cannot_determine_type(self, name: str, context: Context) -> None:
if self.pass_num < LAST_PASS and self.function_stack:
node = self.scope.top_function()
if self.pass_num < LAST_PASS and node is not None:
# Don't report an error yet. Just defer.
node = self.function_stack[-1]
if self.errors.type_name:
type_name = self.errors.type_name[-1]
else:
type_name = None
if self.class_context:
class_context_top = self.class_context[-1]
else:
class_context_top = None
self.deferred_nodes.append(DeferredNode(node, type_name, class_context_top))
# Shouldn't we freeze the entire scope?
active_class = self.scope.active_class()
self.deferred_nodes.append(DeferredNode(node, type_name, active_class))
# Set a marker so that we won't infer additional types in this
# function. Any inferred types could be bogus, because there's at
# least one type that we don't know.
Expand Down Expand Up @@ -508,7 +502,6 @@ def check_func_item(self, defn: FuncItem,
if isinstance(defn, FuncDef):
fdef = defn

self.function_stack.append(defn)
self.dynamic_funcs.append(defn.is_dynamic() and not type_override)

if fdef:
Expand All @@ -530,7 +523,6 @@ def check_func_item(self, defn: FuncItem,
self.errors.pop_function()

self.dynamic_funcs.pop()
self.function_stack.pop()
self.current_node_deferred = False

def check_func_def(self, defn: FuncItem, typ: CallableType, name: str) -> None:
Expand Down Expand Up @@ -616,14 +608,22 @@ def is_implicit_any(t: Type) -> bool:
for i in range(len(typ.arg_types)):
arg_type = typ.arg_types[i]

# Refuse covariant parameter type variables
# TODO: check recuresively for inner type variables
if isinstance(arg_type, TypeVarType):
if i > 0:
if arg_type.variance == COVARIANT:
self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT,
arg_type)
# FIX: if i == 0 and this is not a method then same as above
ref_type = self.scope.active_class()
if (isinstance(defn, FuncDef) and ref_type is not None and i == 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the FuncDef check is to fix the lambda case I reported. This makes me think that was a pretty close call -- I found 1 case of that in over 300K LoC! And the problem was because check_func_def() is called for 'def' as well as for 'lambda', and there are various other checks to weed out static methods and to do the right thing for class and instance methods, but somehow a lambda anywhere in a class would also be checked.

(I'm just muttering to myself here, I just think it's interesting to see how time after time we find that there are more combinations of features than we can possibly test.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me it's "Lesson 8: The State Explosion Problem" :)

This also assumes foo in foo = lambda ... is not a method (from selftype perspective), which I think is fine, but I'd expect to happen once every hundred-thousand lines or so.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, foo = lambda ... is questionable, since when calling it via the class or instance, it will be treated as a method. But you can't have types on lambda args so it shouldn't matter what mypy thinks in that case. :-)

and not defn.is_static
and typ.arg_kinds[0] not in [nodes.ARG_STAR, nodes.ARG_STAR2]):
if defn.is_class or defn.name() == '__new__':
ref_type = mypy.types.TypeType(ref_type)
erased = erase_to_bound(arg_type)
if not is_subtype_ignoring_tvars(ref_type, erased):
self.fail("The erased type of self '{}' "
"is not a supertype of its class '{}'"
.format(erased, ref_type), defn)
elif isinstance(arg_type, TypeVarType):
# Refuse covariant parameter type variables
# TODO: check recuresively for inner type variables
if arg_type.variance == COVARIANT:
self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, arg_type)
if typ.arg_kinds[i] == nodes.ARG_STAR:
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type('builtins.tuple',
Expand All @@ -642,7 +642,8 @@ def is_implicit_any(t: Type) -> bool:

# Type check body in a new scope.
with self.binder.top_frame_context():
self.accept(item.body)
with self.scope.push_function(defn):
self.accept(item.body)
unreachable = self.binder.is_unreachable()

if (self.options.warn_no_return and not unreachable
Expand Down Expand Up @@ -888,7 +889,7 @@ def check_method_override_for_base_with_name(
# The name of the method is defined in the base class.

# Construct the type of the overriding method.
typ = bind_self(self.function_type(defn), self.class_context[-1])
typ = bind_self(self.function_type(defn), self.scope.active_class())
# Map the overridden method type to subtype context so that
# it can be checked for compatibility.
original_type = base_attr.type
Expand All @@ -901,7 +902,7 @@ def check_method_override_for_base_with_name(
assert False, str(base_attr.node)
if isinstance(original_type, FunctionLike):
original = map_type_from_supertype(
bind_self(original_type, self.class_context[-1]),
bind_self(original_type, self.scope.active_class()),
defn.info, base)
# Check that the types are compatible.
# TODO overloaded signatures
Expand Down Expand Up @@ -985,9 +986,8 @@ def visit_class_def(self, defn: ClassDef) -> Type:
old_binder = self.binder
self.binder = ConditionalTypeBinder()
with self.binder.top_frame_context():
self.class_context.append(fill_typevars(defn.info))
self.accept(defn.defs)
self.class_context.pop()
with self.scope.push_class(fill_typevars(defn.info)):
self.accept(defn.defs)
self.binder = old_binder
if not defn.has_incompatible_baseclass:
# Otherwise we've already found errors; more errors are not useful
Expand Down Expand Up @@ -1519,8 +1519,8 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type:
self.binder.unreachable()

def check_return_stmt(self, s: ReturnStmt) -> None:
if self.is_within_function():
defn = self.function_stack[-1]
defn = self.scope.top_function()
if defn is not None:
if defn.is_generator:
return_type = self.get_generator_return_type(self.return_types[-1],
defn.is_coroutine)
Expand All @@ -1537,7 +1537,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
if self.is_unusable_type(return_type):
# Lambdas are allowed to have a unusable returns.
# Functions returning a value of type None are allowed to have a Void return.
if isinstance(self.function_stack[-1], FuncExpr) or isinstance(typ, NoneTyp):
if isinstance(self.scope.top_function(), FuncExpr) or isinstance(typ, NoneTyp):
return
self.fail(messages.NO_RETURN_VALUE_EXPECTED, s)
else:
Expand All @@ -1550,7 +1550,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
msg=messages.INCOMPATIBLE_RETURN_VALUE_TYPE)
else:
# Empty returns are valid in Generators with Any typed returns.
if (self.function_stack[-1].is_generator and isinstance(return_type, AnyType)):
if (defn.is_generator and isinstance(return_type, AnyType)):
return

if isinstance(return_type, (Void, NoneTyp, AnyType)):
Expand Down Expand Up @@ -2318,13 +2318,6 @@ def find_partial_types(self, var: Var) -> Optional[Dict[Var, Context]]:
return partial_types
return None

def is_within_function(self) -> bool:
"""Are we currently type checking within a function?

I.e. not at class body or at the top level.
"""
return self.return_types != []

def is_unusable_type(self, typ: Type):
"""Is this type an unusable type?

Expand Down Expand Up @@ -2756,3 +2749,34 @@ def is_valid_inferred_type_component(typ: Type) -> bool:
if not is_valid_inferred_type_component(item):
return False
return True


class Scope:
# We keep two stacks combined, to maintain the relative order
stack = None # type: List[Union[Type, FuncItem, MypyFile]]

def __init__(self, module: MypyFile) -> None:
self.stack = [module]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the module is ever used. If I replace this with [None] all tests seem to pass. But I'll just merge now, you can fix this up later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not used right now, but I don't see how None is more suitable as a sentinel in this case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was the quickest fix -- if I start with an empty stack I have to add guards against an empty stack in active_class(), which currently looks at stack[-1]. But I agree a guard is better.


def top_function(self) -> Optional[FuncItem]:
for e in reversed(self.stack):
if isinstance(e, FuncItem):
return e
return None

def active_class(self) -> Optional[Type]:
if isinstance(self.stack[-1], Type):
return self.stack[-1]
return None

@contextmanager
def push_function(self, item: FuncItem) -> Iterator[None]:
self.stack.append(item)
yield
self.stack.pop()

@contextmanager
def push_class(self, t: Type) -> Iterator[None]:
self.stack.append(t)
yield
self.stack.pop()
2 changes: 1 addition & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,7 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
return AnyType()
if not self.chk.in_checked_function():
return AnyType()
args = self.chk.function_stack[-1].arguments
args = self.chk.scope.top_function().arguments
# An empty args with super() is an error; we need something in declared_self
if not args:
self.chk.fail('super() requires at least one positional argument', e)
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,4 +614,4 @@ def erase_to_bound(t: Type):
if isinstance(t, TypeType):
if isinstance(t.item, TypeVarType):
return TypeType(t.item.upper_bound)
assert not t
return t
2 changes: 1 addition & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def prepare_method_signature(self, func: FuncDef) -> None:
elif isinstance(functype, CallableType):
self_type = functype.arg_types[0]
if isinstance(self_type, AnyType):
if func.is_class:
if func.is_class or func.name() == '__new__':
leading_type = self.class_type(self.type)
else:
leading_type = fill_typevars(self.type)
Expand Down
85 changes: 85 additions & 0 deletions test-data/unit/check-selftype.test
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,88 @@ class A:
class B(A):
def __init__(self, arg: T) -> None:
super(B, self).__init__()

[case testSelfTypeNonsensical]
# flags: --hide-error-context
from typing import TypeVar, Type

T = TypeVar('T', bound=str)
class A:
def foo(self: T) -> T: # E: The erased type of self 'builtins.str' is not a supertype of its class '__main__.A'
return self

@classmethod
def cfoo(cls: Type[T]) -> T: # E: The erased type of self 'Type[builtins.str]' is not a supertype of its class 'Type[__main__.A]'
return cls()

Q = TypeVar('Q', bound='B')
class B:
def foo(self: Q) -> Q:
return self

@classmethod
def cfoo(cls: Type[Q]) -> Q:
return cls()

class C:
def foo(self: C) -> C: return self

@classmethod
def cfoo(cls: Type[C]) -> C:
return cls()

class D:
def foo(self: str) -> str: # E: The erased type of self 'builtins.str' is not a supertype of its class '__main__.D'
return self

@staticmethod
def bar(self: str) -> str:
return self

@classmethod
def cfoo(cls: Type[str]) -> str: # E: The erased type of self 'Type[builtins.str]' is not a supertype of its class 'Type[__main__.D]'
return cls()

[builtins fixtures/classmethod.pyi]

[case testSelfTypeLambdaDefault]
# flags: --hide-error-context
from typing import Callable
class C:
@classmethod
def foo(cls,
arg: Callable[[int], str] = lambda a: ''
) -> None:
pass

def bar(self,
arg: Callable[[int], str] = lambda a: ''
) -> None:
pass
[builtins fixtures/classmethod.pyi]

[case testSelfTypeNew]
# flags: --hide-error-context
from typing import TypeVar, Type

T = TypeVar('T', bound=A)
class A:
def __new__(cls: Type[T]) -> T:
return cls()

class B:
def __new__(cls: Type[T]) -> T: # E: The erased type of self 'Type[__main__.A]' is not a supertype of its class 'Type[__main__.B]'
return cls()

class C:
def __new__(cls: Type[C]) -> C:
return cls()

class D:
def __new__(cls: D) -> D: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
return cls

class E:
def __new__(cls) -> E:
reveal_type(cls) # E: Revealed type is 'def () -> __main__.E'
return cls()
Loading