diff --git a/mypy/checker.py b/mypy/checker.py index 513a6930a382..de2959d41bc8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -24,27 +24,24 @@ DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase, AwaitExpr, - CONTRAVARIANT, COVARIANT -) -from mypy.nodes import function_type, method_type, method_type_with_fallback + CONTRAVARIANT, COVARIANT) from mypy import nodes from mypy.types import ( Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType, Instance, NoneTyp, ErrorType, strip_type, UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, - true_only, false_only + true_only, false_only, function_type ) from mypy.sametypes import is_same_type from mypy.messages import MessageBuilder import mypy.checkexpr -from mypy.checkmember import map_type_from_supertype +from mypy.checkmember import map_type_from_supertype, bind_self from mypy import messages from mypy.subtypes import ( - is_subtype, is_equivalent, is_proper_subtype, - is_more_precise, restrict_subtype_away + is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away ) from mypy.maptype import map_instance_to_supertype -from mypy.semanal import self_type, set_callable_name, refers_to_fullname +from mypy.semanal import fill_typevars, set_callable_name, refers_to_fullname from mypy.erasetype import erase_typevars from mypy.expandtype import expand_type from mypy.visitor import NodeVisitor @@ -93,6 +90,11 @@ class TypeChecker(NodeVisitor[Type]): # Helper for type checking expressions expr_checker = None # type: mypy.checkexpr.ExpressionChecker + # Class context for checking overriding of a method of the form + # def foo(self: T) -> T + # We need to pass the current class definition for instantiation of T + class_context = None # type: List[Type] + # Stack of function return types return_types = None # type: List[Type] # Type context for type inference @@ -136,6 +138,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option self.path = path self.msg = MessageBuilder(errors, modules) self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg) + self.class_context = [] self.binder = ConditionalTypeBinder() self.globals = tree.names self.return_types = [] @@ -602,11 +605,13 @@ def is_implicit_any(t: Type) -> bool: arg_type = typ.arg_types[i] # Refuse covariant parameter type variables + # TODO: check recuresively for inner type variables if isinstance(arg_type, TypeVarType): - if arg_type.variance == COVARIANT: - self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, - arg_type) - + if i > 0: + if arg_type.variance == COVARIANT: + self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, + arg_type) + # FIX: if i == 0 and this is not a method then same as above if typ.arg_kinds[i] == nodes.ARG_STAR: # builtins.tuple[T] is typing.Tuple[T, ...] arg_type = self.named_generic_type('builtins.tuple', @@ -788,11 +793,11 @@ def check_inplace_operator_method(self, defn: FuncBase) -> None: method = defn.name() if method not in nodes.inplace_operator_methods: return - typ = self.method_type(defn) + typ = bind_self(self.function_type(defn)) cls = defn.info other_method = '__' + method[3:] if cls.has_readable_member(other_method): - instance = self_type(cls) + instance = fill_typevars(cls) typ2 = self.expr_checker.analyze_external_member_access( other_method, instance, defn) fail = False @@ -868,7 +873,7 @@ def check_method_override_for_base_with_name( # The name of the method is defined in the base class. # Construct the type of the overriding method. - typ = self.method_type(defn) + typ = bind_self(self.function_type(defn), self.class_context[-1]) # Map the overridden method type to subtype context so that # it can be checked for compatibility. original_type = base_attr.type @@ -881,7 +886,7 @@ def check_method_override_for_base_with_name( assert False, str(base_attr.node) if isinstance(original_type, FunctionLike): original = map_type_from_supertype( - method_type(original_type), + bind_self(original_type, self.class_context[-1]), defn.info, base) # Check that the types are compatible. # TODO overloaded signatures @@ -965,7 +970,9 @@ def visit_class_def(self, defn: ClassDef) -> Type: old_binder = self.binder self.binder = ConditionalTypeBinder() with self.binder.top_frame_context(): + self.class_context.append(fill_typevars(defn.info)) self.accept(defn.defs) + self.class_context.pop() self.binder = old_binder if not defn.has_incompatible_baseclass: # Otherwise we've already found errors; more errors are not useful @@ -1012,8 +1019,8 @@ def check_compatibility(self, name: str, base1: TypeInfo, if (isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike)): # Method override - first_sig = method_type(first_type) - second_sig = method_type(second_type) + first_sig = bind_self(first_type) + second_sig = bind_self(second_type) ok = is_subtype(first_sig, second_sig) elif first_type and second_type: ok = is_equivalent(first_type, second_type) @@ -2335,9 +2342,6 @@ def iterable_item_type(self, instance: Instance) -> Type: def function_type(self, func: FuncBase) -> FunctionLike: return function_type(func, self.named_type('builtins.function')) - def method_type(self, func: FuncBase) -> FunctionLike: - return method_type_with_fallback(func, self.named_type('builtins.function')) - # TODO: These next two functions should refer to TypeMap below def find_isinstance_check(self, n: Expression) -> Tuple[Optional[Dict[Expression, Type]], Optional[Dict[Expression, Type]]]: @@ -2350,7 +2354,6 @@ def push_type_map(self, type_map: Optional[Dict[Expression, Type]]) -> None: for expr, type in type_map.items(): self.binder.push(expr, type) - # Data structure returned by find_isinstance_check representing # information learned from the truth or falsehood of a condition. The # dict maps nodes representing expressions like 'a[0].x' to their diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7832c7aa7e98..7b5a043293ed 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -6,7 +6,7 @@ Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef, TupleType, Instance, TypeVarId, TypeVarType, ErasedType, UnionType, PartialType, DeletedType, UnboundType, UninhabitedType, TypeType, - true_only, false_only, is_named_instance + true_only, false_only, is_named_instance, function_type ) from mypy.nodes import ( NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, @@ -18,7 +18,6 @@ DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR2, MODULE_REF, ) -from mypy.nodes import function_type from mypy import nodes import mypy.checker from mypy import types @@ -32,7 +31,6 @@ from mypy import applytype from mypy import erasetype from mypy.checkmember import analyze_member_access, type_object_type -from mypy.semanal import self_type from mypy.constraints import get_actual_type from mypy.checkstrformat import StringFormatterChecker from mypy.expandtype import expand_type @@ -1609,10 +1607,18 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type: return AnyType() if not self.chk.in_checked_function(): return AnyType() - return analyze_member_access(e.name, self_type(e.info), e, - is_lvalue, True, False, - self.named_type, self.not_ready_callback, - self.msg, base, chk=self.chk) + args = self.chk.function_stack[-1].arguments + # An empty args with super() is an error; we need something in declared_self + if not args: + self.chk.fail('super() requires at least on positional argument', e) + return AnyType() + declared_self = args[0].variable.type + return analyze_member_access(name=e.name, typ=declared_self, node=e, + is_lvalue=False, is_super=True, is_operator=False, + builtin_type=self.named_type, + not_ready_callback=self.not_ready_callback, + msg=self.msg, override_info=base, chk=self.chk, + original_type=declared_self) else: # Invalid super. This has been reported by the semantic analyzer. return AnyType() diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 96631bed633c..3494f5c37d73 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1,23 +1,23 @@ """Type checking of attribute access""" -from typing import cast, Callable, List, Dict, Optional +from typing import cast, Callable, List, Optional, TypeVar, TYPE_CHECKING from mypy.types import ( - Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarId, TypeVarDef, - Overloaded, TypeVarType, TypeTranslator, UnionType, PartialType, - DeletedType, NoneTyp, TypeType + Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarDef, + Overloaded, TypeVarType, UnionType, PartialType, + DeletedType, NoneTyp, TypeType, function_type ) from mypy.nodes import TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile -from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, OpExpr, ComparisonExpr -from mypy.nodes import function_type, Decorator, OverloadedFuncDef +from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2 +from mypy.nodes import Decorator, OverloadedFuncDef from mypy.messages import MessageBuilder from mypy.maptype import map_instance_to_supertype -from mypy.expandtype import expand_type_by_instance -from mypy.nodes import method_type, method_type_with_fallback -from mypy.semanal import self_type +from mypy.expandtype import expand_type_by_instance, expand_type +from mypy.infer import infer_type_arguments +from mypy.semanal import fill_typevars from mypy import messages from mypy import subtypes -if False: # import for forward declaration only +if TYPE_CHECKING: # import for forward declaration only import mypy.checker from mypy import experiments @@ -33,17 +33,23 @@ def analyze_member_access(name: str, not_ready_callback: Callable[[str, Context], None], msg: MessageBuilder, override_info: TypeInfo = None, - report_type: Type = None, + original_type: Type = None, chk: 'mypy.checker.TypeChecker' = None) -> Type: - """Analyse attribute access. + """Return the type of attribute `name` of typ. This is a general operation that supports various different variations: 1. lvalue or non-lvalue access (i.e. setter or getter access) 2. supertype access (when using super(); is_super == True and override_info should refer to the supertype) + + original_type is the most precise inferred or declared type of the base object + that we have available. typ is generally a supertype of original_type. + When looking for an attribute of typ, we may perform recursive calls targeting + the fallback type, for example. + original_type is always the type used in the initial call. """ - report_type = report_type or typ + original_type = original_type or typ if isinstance(typ, Instance): if name == '__init__' and not is_super: # Accessing __init__ in statically typed code would compromise @@ -71,20 +77,21 @@ def analyze_member_access(name: str, not_ready_callback) if is_lvalue: msg.cant_assign_to_method(node) - typ = map_instance_to_supertype(typ, method.info) + signature = function_type(method, builtin_type('builtins.function')) if name == '__new__': # __new__ is special and behaves like a static method -- don't strip # the first argument. - signature = function_type(method, builtin_type('builtins.function')) + pass else: - signature = method_type_with_fallback(method, builtin_type('builtins.function')) + signature = bind_self(signature, original_type) + typ = map_instance_to_supertype(typ, method.info) return expand_type_by_instance(signature, typ) else: # Not a method. return analyze_member_var_access(name, typ, info, node, is_lvalue, is_super, builtin_type, not_ready_callback, msg, - report_type=report_type, chk=chk) + original_type=original_type, chk=chk) elif isinstance(typ, AnyType): # The base object has dynamic type. return AnyType() @@ -94,7 +101,7 @@ def analyze_member_access(name: str, # The only attribute NoneType has are those it inherits from object return analyze_member_access(name, builtin_type('builtins.object'), node, is_lvalue, is_super, is_operator, builtin_type, not_ready_callback, msg, - report_type=report_type, chk=chk) + original_type=original_type, chk=chk) elif isinstance(typ, UnionType): # The base object has dynamic type. msg.disable_type_names += 1 @@ -130,24 +137,25 @@ def analyze_member_access(name: str, # the corresponding method in the current instance to avoid this edge case. # See https://github.com/python/mypy/pull/1787 for more info. result = analyze_class_attribute_access(ret_type, name, node, is_lvalue, - builtin_type, not_ready_callback, msg) + builtin_type, not_ready_callback, msg, + original_type=original_type) if result: return result # Look up from the 'type' type. return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super, is_operator, builtin_type, not_ready_callback, msg, - report_type=report_type, chk=chk) + original_type=original_type, chk=chk) else: assert False, 'Unexpected type {}'.format(repr(ret_type)) elif isinstance(typ, FunctionLike): # Look up from the 'function' type. return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super, is_operator, builtin_type, not_ready_callback, msg, - report_type=report_type, chk=chk) + original_type=original_type, chk=chk) elif isinstance(typ, TypeVarType): return analyze_member_access(name, typ.upper_bound, node, is_lvalue, is_super, is_operator, builtin_type, not_ready_callback, msg, - report_type=report_type, chk=chk) + original_type=original_type, chk=chk) elif isinstance(typ, DeletedType): msg.deleted_as_rvalue(typ, node) return AnyType() @@ -162,17 +170,18 @@ def analyze_member_access(name: str, if item and not is_operator: # See comment above for why operators are skipped result = analyze_class_attribute_access(item, name, node, is_lvalue, - builtin_type, not_ready_callback, msg) + builtin_type, not_ready_callback, msg, + original_type=original_type) if result: return result fallback = builtin_type('builtins.type') return analyze_member_access(name, fallback, node, is_lvalue, is_super, is_operator, builtin_type, not_ready_callback, msg, - report_type=report_type, chk=chk) + original_type=original_type, chk=chk) if chk and chk.should_suppress_optional_error([typ]): return AnyType() - return msg.has_no_attr(report_type, name, node) + return msg.has_no_attr(original_type, name, node) def analyze_member_var_access(name: str, itype: Instance, info: TypeInfo, @@ -180,13 +189,15 @@ def analyze_member_var_access(name: str, itype: Instance, info: TypeInfo, builtin_type: Callable[[str], Instance], not_ready_callback: Callable[[str, Context], None], msg: MessageBuilder, - report_type: Type = None, + original_type: Type = None, chk: 'mypy.checker.TypeChecker' = None) -> Type: """Analyse attribute access that does not target a method. - This is logically part of analyze_member_access and the arguments are - similar. + This is logically part of analyze_member_access and the arguments are similar. + + original_type is the type of E in the expression E.var """ + original_type = original_type or itype # It was not a method. Try looking up a variable. v = lookup_member_var_or_accessor(info, name, is_lvalue) @@ -202,9 +213,10 @@ def analyze_member_var_access(name: str, itype: Instance, info: TypeInfo, if not is_lvalue: method = info.get_method('__getattr__') if method: + function = function_type(method, builtin_type('builtins.function')) + bound_method = bind_self(function, original_type) typ = map_instance_to_supertype(itype, method.info) - getattr_type = expand_type_by_instance( - method_type_with_fallback(method, builtin_type('builtins.function')), typ) + getattr_type = expand_type_by_instance(bound_method, typ) if isinstance(getattr_type, CallableType): return getattr_type.ret_type @@ -218,16 +230,20 @@ def analyze_member_var_access(name: str, itype: Instance, info: TypeInfo, else: if chk and chk.should_suppress_optional_error([itype]): return AnyType() - return msg.has_no_attr(report_type or itype, name, node) + return msg.has_no_attr(original_type, name, node) def analyze_var(name: str, var: Var, itype: Instance, info: TypeInfo, node: Context, is_lvalue: bool, msg: MessageBuilder, - not_ready_callback: Callable[[str, Context], None]) -> Type: + not_ready_callback: Callable[[str, Context], None], + original_type: Type = None) -> Type: """Analyze access to an attribute via a Var node. This is conceptually part of analyze_member_access and the arguments are similar. + + original_type is the type of E in the expression E.var """ + original_type = original_type or itype # Found a member variable. itype = map_instance_to_supertype(itype, var.info) typ = var.type @@ -252,7 +268,7 @@ def analyze_var(name: str, var: Var, itype: Instance, info: TypeInfo, node: Cont # class. functype = t check_method_type(functype, itype, var.is_classmethod, node, msg) - signature = method_type(functype) + signature = bind_self(functype, original_type) if var.is_property: # A property cannot have an overloaded type => the cast # is fine. @@ -327,7 +343,9 @@ def analyze_class_attribute_access(itype: Instance, is_lvalue: bool, builtin_type: Callable[[str], Instance], not_ready_callback: Callable[[str, Context], None], - msg: MessageBuilder) -> Type: + msg: MessageBuilder, + original_type: Type = None) -> Type: + """original_type is the type of E in the expression E.var""" node = itype.type.get(name) if not node: if itype.type.fallback_to_any: @@ -350,7 +368,7 @@ def analyze_class_attribute_access(itype: Instance, if isinstance(t, PartialType): return handle_partial_attribute_type(t, is_lvalue, msg, node.node) is_classmethod = is_decorated and cast(Decorator, node.node).func.is_class - return add_class_tvars(t, itype.type, is_classmethod, builtin_type) + return add_class_tvars(t, itype, is_classmethod, builtin_type, original_type) elif isinstance(node.node, Var): not_ready_callback(name, context) return AnyType() @@ -369,24 +387,36 @@ def analyze_class_attribute_access(itype: Instance, return function_type(cast(FuncBase, node.node), builtin_type('builtins.function')) -def add_class_tvars(t: Type, info: TypeInfo, is_classmethod: bool, - builtin_type: Callable[[str], Instance]) -> Type: +def add_class_tvars(t: Type, itype: Instance, is_classmethod: bool, + builtin_type: Callable[[str], Instance], + original_type: Type = None) -> Type: + """Instantiate type variables during analyze_class_attribute_access, + e.g T and Q in the following: + + def A(Generic(T)): + @classmethod + def foo(cls: Type[Q]) -> Tuple[T, Q]: ... + + class B(A): pass + + B.foo() + + original_type is the value of the type B in the expression B.foo() + """ + # TODO: verify consistency betweem Q and T + info = itype.type # type: TypeInfo if isinstance(t, CallableType): # TODO: Should we propagate type variable values? vars = [TypeVarDef(n, i + 1, None, builtin_type('builtins.object'), tv.variance) for (i, n), tv in zip(enumerate(info.type_vars), info.defn.type_vars)] - arg_types = t.arg_types - arg_kinds = t.arg_kinds - arg_names = t.arg_names if is_classmethod: - arg_types = arg_types[1:] - arg_kinds = arg_kinds[1:] - arg_names = arg_names[1:] - return t.copy_modified(arg_types=arg_types, arg_kinds=arg_kinds, arg_names=arg_names, - variables=vars + t.variables) + if not isinstance(original_type, TypeType): + original_type = TypeType(itype) + t = bind_self(t, original_type) + return t.copy_modified(variables=vars + t.variables) elif isinstance(t, Overloaded): - return Overloaded([cast(CallableType, add_class_tvars(i, info, is_classmethod, - builtin_type)) + return Overloaded([cast(CallableType, add_class_tvars(i, itype, is_classmethod, + builtin_type, original_type)) for i in t.items()]) return t @@ -430,7 +460,7 @@ def type_object_type(info: TypeInfo, builtin_type: Callable[[str], Instance]) -> def type_object_type_from_function(init_or_new: FuncBase, info: TypeInfo, fallback: Instance) -> FunctionLike: - signature = method_type_with_fallback(init_or_new, fallback) + signature = bind_self(function_type(init_or_new, fallback)) # The __init__ method might come from a generic superclass # (init_or_new.info) with type variables that do not map @@ -468,7 +498,7 @@ def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance, variables.extend(init_type.variables) callable_type = init_type.copy_modified( - ret_type=self_type(info), fallback=type_type, name=None, variables=variables, + ret_type=fill_typevars(info), fallback=type_type, name=None, variables=variables, special_sig=special_sig) c = callable_type.with_name('"{}"'.format(info.name())) c.is_classmethod_class = True @@ -489,7 +519,7 @@ def map_type_from_supertype(typ: Type, sub_info: TypeInfo, Now S in the context of D would be mapped to E[T] in the context of C. """ # Create the type of self in subtype, of form t[a1, ...]. - inst_type = self_type(sub_info) + inst_type = fill_typevars(sub_info) if isinstance(inst_type, TupleType): inst_type = inst_type.fallback # Map the type of self to supertype. This gets us a description of the @@ -503,3 +533,82 @@ def map_type_from_supertype(typ: Type, sub_info: TypeInfo, # in inst_type they are interpreted in subtype context. This works even if # the names of type variables in supertype and subtype overlap. return expand_type_by_instance(typ, inst_type) + + +F = TypeVar('F', bound=FunctionLike) + + +def bind_self(method: F, original_type: Type = None) -> F: + """Return a copy of `method`, with the type of its first parameter (usually + self or cls) bound to original_type. + + If the type of `self` is a generic type (T, or Type[T] for classmethods), + instantiate every occurrence of type with original_type in the rest of the + signature and in the return type. + + original_type is the type of E in the expression E.copy(). It is None in + compatibility checks. In this case we treat it as the erasure of the + declared type of self. + + This way we can express "the type of self". For example: + + T = TypeVar('T', bound='A') + class A: + def copy(self: T) -> T: ... + + class B(A): pass + + b = B().copy() # type: B + + """ + if isinstance(method, Overloaded): + return cast(F, Overloaded([bind_self(c, method) for c in method.items()])) + assert isinstance(method, CallableType) + func = method + if not func.arg_types: + # invalid method. return something + return cast(F, func) + if func.arg_kinds[0] == ARG_STAR: + # The signature is of the form 'def foo(*args, ...)'. + # In this case we shouldn'func drop the first arg, + # since func will be absorbed by the *args. + + # TODO: infer bounds on the type of *args? + return cast(F, func) + self_param_type = func.arg_types[0] + if func.variables and (isinstance(self_param_type, TypeVarType) or + (isinstance(self_param_type, TypeType) and + isinstance(self_param_type.item, TypeVarType))): + if original_type is None: + # Type check method override + # XXX value restriction as union? + original_type = erase_to_bound(self_param_type) + + typearg = infer_type_arguments([x.id for x in func.variables], + self_param_type, original_type)[0] + + def expand(target: Type) -> Type: + return expand_type(target, {func.variables[0].id: typearg}) + + arg_types = [expand(x) for x in func.arg_types[1:]] + ret_type = expand(func.ret_type) + variables = func.variables[1:] + else: + arg_types = func.arg_types[1:] + ret_type = func.ret_type + variables = func.variables + res = func.copy_modified(arg_types=arg_types, + arg_kinds=func.arg_kinds[1:], + arg_names=func.arg_names[1:], + variables=variables, + ret_type=ret_type) + return cast(F, res) + + +def erase_to_bound(t: Type): + if isinstance(t, TypeVarType): + return t.upper_bound + if isinstance(t, TypeType): + if isinstance(t.item, TypeVarType): + return TypeType(t.item.upper_bound) + assert not t diff --git a/mypy/constraints.py b/mypy/constraints.py index f204c026a30a..e26e583522ab 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -1,6 +1,6 @@ """Type inference constraints.""" -from typing import List, Optional, cast +from typing import List, Optional from mypy.types import ( CallableType, Type, TypeVisitor, UnboundType, AnyType, Void, NoneTyp, TypeVarType, diff --git a/mypy/expandtype.py b/mypy/expandtype.py index c299163b747b..7c00aa0cf347 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -1,4 +1,4 @@ -from typing import Dict, Tuple, List, cast +from typing import Dict, List, cast from mypy.types import ( Type, Instance, CallableType, TypeVisitor, UnboundType, ErrorType, AnyType, diff --git a/mypy/nodes.py b/mypy/nodes.py index 47ad547ac6a4..28089829adab 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2223,57 +2223,6 @@ def deserialize(cls, data: JsonDict) -> 'SymbolTable': return st -def function_type(func: FuncBase, fallback: 'mypy.types.Instance') -> 'mypy.types.FunctionLike': - if func.type: - assert isinstance(func.type, mypy.types.FunctionLike) - return func.type - else: - # Implicit type signature with dynamic types. - # Overloaded functions always have a signature, so func must be an ordinary function. - fdef = cast(FuncDef, func) - name = func.name() - if name: - name = '"{}"'.format(name) - - return mypy.types.CallableType( - [mypy.types.AnyType()] * len(fdef.arg_names), - fdef.arg_kinds, - fdef.arg_names, - mypy.types.AnyType(), - fallback, - name, - implicit=True, - ) - - -def method_type_with_fallback(func: FuncBase, - fallback: 'mypy.types.Instance') -> 'mypy.types.FunctionLike': - """Return the signature of a method (omit self).""" - return method_type(function_type(func, fallback)) - - -def method_type(sig: 'mypy.types.FunctionLike') -> 'mypy.types.FunctionLike': - if isinstance(sig, mypy.types.CallableType): - return method_callable(sig) - else: - sig = cast(mypy.types.Overloaded, sig) - items = [] # type: List[mypy.types.CallableType] - for c in sig.items(): - items.append(method_callable(c)) - return mypy.types.Overloaded(items) - - -def method_callable(c: 'mypy.types.CallableType') -> 'mypy.types.CallableType': - if c.arg_kinds and c.arg_kinds[0] == ARG_STAR: - # The signature is of the form 'def foo(*args, ...)'. - # In this case we shouldn't drop the first arg, - # since self will be absorbed by the *args. - return c - return c.copy_modified(arg_types=c.arg_types[1:], - arg_kinds=c.arg_kinds[1:], - arg_names=c.arg_names[1:]) - - class MroError(Exception): """Raised if a consistent mro cannot be determined for a class.""" diff --git a/mypy/semanal.py b/mypy/semanal.py index 8fe0dfdde5c6..b0e3a8addc85 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -72,8 +72,8 @@ from mypy.types import ( NoneTyp, CallableType, Overloaded, Instance, Type, TypeVarType, AnyType, FunctionLike, UnboundType, TypeList, TypeVarDef, - replace_leading_arg_type, TupleType, UnionType, StarType, EllipsisType, TypeType) -from mypy.nodes import function_type, implicit_module_attrs + TupleType, UnionType, StarType, EllipsisType, function_type) +from mypy.nodes import implicit_module_attrs from mypy.typeanal import TypeAnalyser, TypeAnalyserPass3, analyze_type_alias from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.sametypes import is_same_type @@ -320,16 +320,19 @@ def visit_func_def(self, defn: FuncDef) -> None: def prepare_method_signature(self, func: FuncDef) -> None: """Check basic signature validity and tweak annotation of self/cls argument.""" # Only non-static methods are special. + functype = func.type if not func.is_static: if not func.arguments: self.fail('Method must have at least one argument', func) - elif func.type: - sig = cast(FunctionLike, func.type) - if func.is_class: - leading_type = self.class_type(self.type) - else: - leading_type = self_type(self.type) - func.type = replace_implicit_first_type(sig, leading_type) + elif isinstance(functype, CallableType): + self_type = functype.arg_types[0] + if isinstance(self_type, AnyType): + if func.is_class: + leading_type = self.class_type(self.type) + else: + leading_type = fill_typevars(self.type) + sig = cast(FunctionLike, func.type) + func.type = replace_implicit_first_type(sig, leading_type) def is_conditional_func(self, previous: Node, new: FuncDef) -> bool: """Does 'new' conditionally redefine 'previous'? @@ -1779,7 +1782,7 @@ def add_field(var: Var, is_initialized_in_class: bool = False, add_field(Var('_source', strtype), is_initialized_in_class=True) # TODO: SelfType should be bind to actual 'self' - this_type = self_type(info) + this_type = fill_typevars(info) def add_method(funcname: str, ret: Type, args: List[Argument], name=None, is_classmethod=False) -> None: @@ -3036,7 +3039,7 @@ def builtin_type(self, name: str, args: List[Type] = None) -> Instance: return Instance(sym.node, args or []) -def self_type(typ: TypeInfo) -> Union[Instance, TupleType]: +def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]: """For a non-generic type, return instance type representing the type. For a generic G type with parameters T1, .., Tn, return G[T1, ..., Tn]. """ @@ -3051,7 +3054,7 @@ def self_type(typ: TypeInfo) -> Union[Instance, TupleType]: def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: if isinstance(sig, CallableType): - return replace_leading_arg_type(sig, new) + return sig.copy_modified(arg_types=[new] + sig.arg_types[1:]) else: sig = cast(Overloaded, sig) return Overloaded([cast(CallableType, replace_implicit_first_type(i, new)) diff --git a/mypy/solve.py b/mypy/solve.py index 1ebeb923f5a3..2d24acf00525 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -1,6 +1,7 @@ """Type inference constraint solving""" -from typing import List, Dict +from typing import List, Dict, DefaultDict +from collections import defaultdict from mypy.types import Type, Void, NoneTyp, AnyType, ErrorType, UninhabitedType, TypeVarId from mypy.constraints import Constraint, SUPERTYPE_OF @@ -23,11 +24,9 @@ def solve_constraints(vars: List[TypeVarId], constraints: List[Constraint], pick AnyType. """ # Collect a list of constraints for each type variable. - cmap = {} # type: Dict[TypeVarId, List[Constraint]] + cmap = defaultdict(list) # type: DefaultDict[TypeVarId, List[Constraint]] for con in constraints: - a = cmap.get(con.type_var, []) # type: List[Constraint] - a.append(con) - cmap[con.type_var] = a + cmap[con.type_var].append(con) res = [] # type: List[Type] diff --git a/mypy/subtypes.py b/mypy/subtypes.py index fff5df0f11e5..6fc335d98735 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -140,10 +140,9 @@ def visit_instance(self, left: Instance) -> bool: def visit_type_var(self, left: TypeVarType) -> bool: right = self.right - if isinstance(right, TypeVarType): - return left.id == right.id - else: - return is_subtype(left.upper_bound, self.right) + if isinstance(right, TypeVarType) and left.id == right.id: + return True + return is_subtype(left.upper_bound, self.right) def visit_callable_type(self, left: CallableType) -> bool: right = self.right diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index f63e0ccbe2d5..b795bd5dcfaf 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -68,6 +68,7 @@ 'check-newtype.test', 'check-class-namedtuple.test', 'check-columns.test', + 'check-selftype.test', ] if 'annotation' in typed_ast.ast35.Assign._fields: diff --git a/mypy/types.py b/mypy/types.py index 09e473d213f4..9c80b590cd38 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -181,6 +181,7 @@ class UnboundType(Type): optional = False # is this type a return type? is_ret_type = False + # special case for X[()] empty_tuple_index = False @@ -211,7 +212,7 @@ def serialize(self) -> JsonDict: } @classmethod - def deserialize(self, data: JsonDict) -> 'UnboundType': + def deserialize(cls, data: JsonDict) -> 'UnboundType': assert data['.class'] == 'UnboundType' return UnboundType(data['name'], [Type.deserialize(a) for a in data['args']]) @@ -247,7 +248,7 @@ def serialize(self) -> JsonDict: } @classmethod - def deserialize(self, data: JsonDict) -> 'TypeList': + def deserialize(cls, data: JsonDict) -> 'TypeList': assert data['.class'] == 'TypeList' return TypeList([Type.deserialize(t) for t in data['items']]) @@ -365,7 +366,7 @@ def serialize(self) -> JsonDict: } @classmethod - def deserialize(self, data: JsonDict) -> 'NoneTyp': + def deserialize(cls, data: JsonDict) -> 'NoneTyp': assert data['.class'] == 'NoneTyp' return NoneTyp(is_ret_type=data['is_ret_type']) @@ -401,7 +402,7 @@ def serialize(self) -> JsonDict: 'source': self.source} @classmethod - def deserialize(self, data: JsonDict) -> 'DeletedType': + def deserialize(cls, data: JsonDict) -> 'DeletedType': assert data['.class'] == 'DeletedType' return DeletedType(data['source']) @@ -414,7 +415,7 @@ class Instance(Type): type = None # type: mypy.nodes.TypeInfo args = None # type: List[Type] - erased = False # True if result of type variable substitution + erased = False # True if result of type variable substitution def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type], line: int = -1, column: int = -1, erased: bool = False) -> None: @@ -746,7 +747,7 @@ def serialize(self) -> JsonDict: } @classmethod - def deserialize(self, data: JsonDict) -> 'Overloaded': + def deserialize(cls, data: JsonDict) -> 'Overloaded': assert data['.class'] == 'Overloaded' return Overloaded([CallableType.deserialize(t) for t in data['items']]) @@ -955,7 +956,7 @@ def serialize(self) -> JsonDict: return {'.class': 'EllipsisType'} @classmethod - def deserialize(self, data: JsonDict) -> 'EllipsisType': + def deserialize(cls, data: JsonDict) -> 'EllipsisType': assert data['.class'] == 'EllipsisType' return EllipsisType() @@ -1438,14 +1439,6 @@ def strip_type(typ: Type) -> Type: return typ -def replace_leading_arg_type(t: CallableType, self_type: Type) -> CallableType: - """Return a copy of a callable type with a different self argument type. - - Assume that the callable is the signature of a method. - """ - return t.copy_modified(arg_types=[self_type] + t.arg_types[1:]) - - def is_named_instance(t: Type, fullname: str) -> bool: return (isinstance(t, Instance) and t.type is not None and @@ -1507,3 +1500,27 @@ def true_or_false(t: Type) -> Type: new_t.can_be_true = type(new_t).can_be_true new_t.can_be_false = type(new_t).can_be_false return new_t + + +def function_type(func: mypy.nodes.FuncBase, fallback: Instance) -> FunctionLike: + if func.type: + assert isinstance(func.type, FunctionLike) + return func.type + else: + # Implicit type signature with dynamic types. + # Overloaded functions always have a signature, so func must be an ordinary function. + assert isinstance(func, mypy.nodes.FuncItem), str(func) + fdef = cast(mypy.nodes.FuncItem, func) + name = func.name() + if name: + name = '"{}"'.format(name) + + return CallableType( + [AnyType()] * len(fdef.arg_names), + fdef.arg_kinds, + fdef.arg_names, + AnyType(), + fallback, + name, + implicit=True, + ) diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test new file mode 100644 index 000000000000..d7b823695ae5 --- /dev/null +++ b/test-data/unit/check-selftype.test @@ -0,0 +1,238 @@ +[case testSelfTypeInstance] +from typing import TypeVar + +T = TypeVar('T', bound='A', covariant=True) + +class A: + def copy(self: T) -> T: pass + +class B(A): + pass + +reveal_type(A().copy) # E: Revealed type is 'def () -> __main__.A*' +reveal_type(B().copy) # E: Revealed type is 'def () -> __main__.B*' +reveal_type(A().copy()) # E: Revealed type is '__main__.A*' +reveal_type(B().copy()) # E: Revealed type is '__main__.B*' + +[builtins fixtures/bool.pyi] + +[case testSelfTypeStaticAccess] +from typing import TypeVar + +T = TypeVar('T', bound='A', covariant=True) +class A: + def copy(self: T) -> T: pass + +class B(A): + pass + +# Erased instances appear on reveal_type; unrelated to self type +def f(a: A) -> None: pass +f(A.copy(A())) +f(A.copy(B())) +f(B.copy(B())) + +# TODO: make it an error +# f(B.copy(A())) + +def g(a: B) -> None: pass +g(A.copy(A())) # E: Argument 1 to "g" has incompatible type "A"; expected "B" +g(A.copy(B())) +g(B.copy(B())) + +[builtins fixtures/bool.pyi] + +[case testSelfTypeReturn] +# flags: --hide-error-context +from typing import TypeVar, Type + +R = TypeVar('R') +def _type(self: R) -> Type[R]: pass + +T = TypeVar('T', bound='A', covariant=True) +class A: + def copy(self: T) -> T: + if B(): + return A() # E: Incompatible return value type (got "A", expected "T") + elif A(): + return B() # E: Incompatible return value type (got "B", expected "T") + reveal_type(_type(self)) # E: Revealed type is 'Type[T`-1]' + return reveal_type(_type(self)()) # E: Revealed type is 'T`-1' + +class B(A): + pass + +Q = TypeVar('Q', bound='C', covariant=True) +class C: + def __init__(self, a: int) -> None: pass + + def copy(self: Q) -> Q: + if self: + return reveal_type(_type(self)(1)) # E: Revealed type is 'Q`-1' + else: + return _type(self)() # E: Too few arguments for "C" + + +[builtins fixtures/bool.pyi] + +[case testSelfTypeClass] +# flags: --hide-error-context +from typing import TypeVar, Type + +T = TypeVar('T', bound='A') + +class A: + @classmethod + def new(cls: Type[T]) -> T: + return reveal_type(cls()) # E: Revealed type is 'T`-1' + +class B(A): + pass + +Q = TypeVar('Q', bound='C', covariant=True) +class C: + def __init__(self, a: int) -> None: pass + + @classmethod + def new(cls: Type[Q]) -> Q: + if cls: + return cls(1) + else: + return cls() # E: Too few arguments for "C" + + +reveal_type(A.new) # E: Revealed type is 'def () -> __main__.A*' +reveal_type(B.new) # E: Revealed type is 'def () -> __main__.B*' +reveal_type(A.new()) # E: Revealed type is '__main__.A*' +reveal_type(B.new()) # E: Revealed type is '__main__.B*' + +[builtins fixtures/classmethod.pyi] + +[case testSelfTypeOverride] +from typing import TypeVar, cast + +T = TypeVar('T', bound='A', covariant=True) + +class A: + def copy(self: T) -> T: pass + +class B(A): + pass + +Q = TypeVar('Q', bound='C', covariant=True) +class C(A): + def copy(self: Q) -> Q: pass + +reveal_type(C().copy) # E: Revealed type is 'def () -> __main__.C*' +reveal_type(C().copy()) # E: Revealed type is '__main__.C*' +reveal_type(cast(A, C()).copy) # E: Revealed type is 'def () -> __main__.A*' +reveal_type(cast(A, C()).copy()) # E: Revealed type is '__main__.A*' + +[builtins fixtures/bool.pyi] + +[case testSelfTypeSuper] +# flags: --hide-error-context +from typing import TypeVar, cast + +T = TypeVar('T', bound='A', covariant=True) + +class A: + def copy(self: T) -> T: pass + +Q = TypeVar('Q', bound='B', covariant=True) +class B(A): + def copy(self: Q) -> Q: + reveal_type(self) # E: Revealed type is 'Q`-1' + reveal_type(super().copy) # E: Revealed type is 'def () -> Q`-1' + return super().copy() + +[builtins fixtures/bool.pyi] + +[case testSelfTypeRecursiveBinding] +# flags: --hide-error-context +from typing import TypeVar, Callable, Type + +T = TypeVar('T', bound='A', covariant=True) +class A: + # TODO: This is potentially unsafe, as we use T in an argument type + def copy(self: T, factory: Callable[[T], T]) -> T: + return factory(self) + + @classmethod + def new(cls: Type[T], factory: Callable[[T], T]) -> T: + reveal_type(cls) # E: Revealed type is 'Type[T`-1]' + reveal_type(cls()) # E: Revealed type is 'T`-1' + cls(2) # E: Too many arguments for "A" + return cls() + +class B(A): + pass + +reveal_type(A().copy) # E: Revealed type is 'def (factory: def (__main__.A*) -> __main__.A*) -> __main__.A*' +reveal_type(B().copy) # E: Revealed type is 'def (factory: def (__main__.B*) -> __main__.B*) -> __main__.B*' +reveal_type(A.new) # E: Revealed type is 'def (factory: def (__main__.A*) -> __main__.A*) -> __main__.A*' +reveal_type(B.new) # E: Revealed type is 'def (factory: def (__main__.B*) -> __main__.B*) -> __main__.B*' + +[builtins fixtures/classmethod.pyi] + +[case testSelfTypeBound] +# flags: --hide-error-context +from typing import TypeVar, Callable, cast + +TA = TypeVar('TA', bound='A', covariant=True) + +class A: + def copy(self: TA) -> TA: + pass + +class C(A): + def copy(self: C) -> C: + pass + +class D(A): + def copy(self: A) -> A: # E: Return type of "copy" incompatible with supertype "A" + pass + +TB = TypeVar('TB', bound='B', covariant=True) +class B(A): + x = 1 + def copy(self: TB) -> TB: + reveal_type(self.x) # E: Revealed type is 'builtins.int' + return cast(TB, None) + +[builtins fixtures/bool.pyi] + +-- # TODO: fail for this +-- [case testSelfTypeBare] +-- # flags: --hide-error-context +-- from typing import TypeVar, Type +-- +-- T = TypeVar('T', bound='E') +-- +-- class E: +-- def copy(self: T, other: T) -> T: pass + +[case testSelfTypeClone] +# flags: --hide-error-context +from typing import TypeVar, Type + +T = TypeVar('T', bound='C') + +class C: + def copy(self: T) -> T: + return self + + @classmethod + def new(cls: Type[T]) -> T: + return cls() + +def clone(arg: T) -> T: + reveal_type(arg.copy) # E: Revealed type is 'def () -> T`-1' + return arg.copy() + + +def make(cls: Type[T]) -> T: + reveal_type(cls.new) # E: Revealed type is 'def () -> T`-1' + return cls.new() + +[builtins fixtures/classmethod.pyi] diff --git a/test-data/unit/fixtures/classmethod.pyi b/test-data/unit/fixtures/classmethod.pyi index f6333bd7b289..282839dcef28 100644 --- a/test-data/unit/fixtures/classmethod.pyi +++ b/test-data/unit/fixtures/classmethod.pyi @@ -19,3 +19,4 @@ class int: class str: pass class bytes: pass +class bool: pass