diff --git a/mypy/checker.py b/mypy/checker.py index 9b0a728552e6..d093293a5b2c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from typing import ( - Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator + Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable ) from mypy.errors import Errors, report_internal_error @@ -30,7 +30,7 @@ Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType, Instance, NoneTyp, strip_type, TypeType, TypeOfAny, UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, - true_only, false_only, function_type, is_named_instance, union_items, + true_only, false_only, function_type, is_named_instance, union_items, TypeQuery ) from mypy.sametypes import is_same_type, is_same_types from mypy.messages import MessageBuilder, make_inferred_type_note @@ -55,7 +55,7 @@ from mypy.join import join_types from mypy.treetransform import TransformVisitor from mypy.binder import ConditionalTypeBinder, get_declaration -from mypy.meet import is_overlapping_types, is_partially_overlapping_types +from mypy.meet import is_overlapping_erased_types, is_overlapping_types from mypy.options import Options from mypy.plugin import Plugin, CheckerPluginInterface from mypy.sharedparse import BINARY_MAGIC_METHODS @@ -471,6 +471,9 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: if is_unsafe_overlapping_overload_signatures(sig1, sig2): self.msg.overloaded_signatures_overlap( i + 1, i + j + 2, item.func) + elif is_unsafe_partially_overlapping_overload_signatures(sig1, sig2): + self.msg.overloaded_signatures_partial_overlap( + i + 1, i + j + 2, item.func) if impl_type is not None: assert defn.impl is not None @@ -495,12 +498,12 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # Is the overload alternative's arguments subtypes of the implementation's? if not is_callable_compatible(impl, sig1, - is_compat=is_subtype, + is_compat=is_subtype_no_promote, ignore_return=True): self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl) # Is the overload alternative's return type a subtype of the implementation's? - if not is_subtype(sig1.ret_type, impl.ret_type): + if not is_subtype_no_promote(sig1.ret_type, impl.ret_type): self.msg.overloaded_signatures_ret_specific(i + 1, defn.impl) # Here's the scoop about generators and coroutines. @@ -1043,72 +1046,102 @@ def check_overlapping_op_methods(self, """Check for overlapping method and reverse method signatures. Assume reverse method has valid argument count and kinds. + + Precondition: + If the reverse operator method accepts some argument of type + X, the forward operator method must belong to class X. + + For example, if we have the reverse operator `A.__radd__(B)`, then the + corresponding forward operator must have the type `B.__add__(...)`. """ - # Reverse operator method that overlaps unsafely with the - # forward operator method can result in type unsafety. This is - # similar to overlapping overload variants. + # Note: Suppose we have two operator methods "A.__rOP__(B) -> R1" and + # "B.__OP__(C) -> R2". We check if these two methods are unsafely overlapping + # by using the following algorithm: + # + # 1. Rewrite "B.__OP__(C) -> R1" to "temp1(B, C) -> R1" + # + # 2. Rewrite "A.__rOP__(B) -> R2" to "temp2(B, A) -> R2" # - # This example illustrates the issue: + # 3. Treat temp1 and temp2 as if they were both variants in the same + # overloaded function. (This mirrors how the Python runtime calls + # operator methods: we first try __OP__, then __rOP__.) # - # class X: pass - # class A: - # def __add__(self, x: X) -> int: - # if isinstance(x, X): - # return 1 - # return NotImplemented - # class B: - # def __radd__(self, x: A) -> str: return 'x' - # class C(X, B): pass - # def f(b: B) -> None: - # A() + b # Result is 1, even though static type seems to be str! - # f(C()) + # If the first signature is unsafely overlapping with the second, + # report an error. # - # The reason for the problem is that B and X are overlapping - # types, and the return types are different. Also, if the type - # of x in __radd__ would not be A, the methods could be - # non-overlapping. + # 4. However, if temp1 shadows temp2 (e.g. the __rOP__ method can never + # be called), do NOT report an error. + # + # This behavior deviates from how we handle overloads -- many of the + # modules in typeshed seem to define __OP__ methods that shadow the + # corresponding __rOP__ method. + # + # Note: we do not attempt to handle unsafe overlaps related to multiple + # inheritance. for forward_item in union_items(forward_type): if isinstance(forward_item, CallableType): - # TODO check argument kinds - if len(forward_item.arg_types) < 1: - # Not a valid operator method -- can't succeed anyway. - return - - # Construct normalized function signatures corresponding to the - # operator methods. The first argument is the left operand and the - # second operand is the right argument -- we switch the order of - # the arguments of the reverse method. - forward_tweaked = CallableType( - [forward_base, forward_item.arg_types[0]], - [nodes.ARG_POS] * 2, - [None] * 2, - forward_item.ret_type, - forward_item.fallback, - name=forward_item.name) - reverse_args = reverse_type.arg_types - reverse_tweaked = CallableType( - [reverse_args[1], reverse_args[0]], - [nodes.ARG_POS] * 2, - [None] * 2, - reverse_type.ret_type, - fallback=self.named_type('builtins.function'), - name=reverse_type.name) - - if is_unsafe_overlapping_operator_signatures( - forward_tweaked, reverse_tweaked): + if self.is_unsafe_overlapping_op(forward_item, forward_base, reverse_type): self.msg.operator_method_signatures_overlap( reverse_class, reverse_name, forward_base, forward_name, context) elif isinstance(forward_item, Overloaded): for item in forward_item.items(): - self.check_overlapping_op_methods( - reverse_type, reverse_name, reverse_class, - item, forward_name, forward_base, context) + if self.is_unsafe_overlapping_op(item, forward_base, reverse_type): + self.msg.operator_method_signatures_overlap( + reverse_class, reverse_name, + forward_base, forward_name, + context) elif not isinstance(forward_item, AnyType): self.msg.forward_operator_not_callable(forward_name, context) + def is_unsafe_overlapping_op(self, + forward_item: CallableType, + forward_base: Type, + reverse_type: CallableType) -> bool: + # TODO check argument kinds + if len(forward_item.arg_types) < 1: + # Not a valid operator method -- can't succeed anyway. + return False + + # Erase the type if necessary to make sure we don't have a dangling + # TypeVar in forward_tweaked + forward_base_erased = forward_base + if isinstance(forward_base, TypeVarType): + forward_base_erased = erase_to_bound(forward_base) + + # Construct normalized function signatures corresponding to the + # operator methods. The first argument is the left operand and the + # second operand is the right argument -- we switch the order of + # the arguments of the reverse method. + + forward_tweaked = forward_item.copy_modified( + arg_types=[forward_base_erased, forward_item.arg_types[0]], + arg_kinds=[nodes.ARG_POS] * 2, + arg_names=[None] * 2, + ) + reverse_tweaked = reverse_type.copy_modified( + arg_types=[reverse_type.arg_types[1], reverse_type.arg_types[0]], + arg_kinds=[nodes.ARG_POS] * 2, + arg_names=[None] * 2, + ) + + reverse_base_erased = reverse_type.arg_types[0] + if isinstance(reverse_base_erased, TypeVarType): + reverse_base_erased = erase_to_bound(reverse_base_erased) + + if is_same_type(reverse_base_erased, forward_base_erased): + return False + elif is_subtype(reverse_base_erased, forward_base_erased): + first = reverse_tweaked + second = forward_tweaked + else: + first = forward_tweaked + second = reverse_tweaked + + return is_unsafe_partially_overlapping_overload_signatures(first, second) + def check_inplace_operator_method(self, defn: FuncBase) -> None: """Check an inplace operator method such as __iadd__. @@ -3088,7 +3121,7 @@ def find_isinstance_check(self, node: Expression else: optional_type, comp_type = second_type, first_type optional_expr = node.operands[1] - if is_overlapping_types(optional_type, comp_type): + if is_overlapping_erased_types(optional_type, comp_type): return {optional_expr: remove_optional(optional_type)}, {} elif node.operators in [['in'], ['not in']]: expr = node.operands[0] @@ -3099,7 +3132,7 @@ def find_isinstance_check(self, node: Expression right_type.type.fullname() != 'builtins.object')) if (right_type and right_ok and is_optional(left_type) and literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and - is_overlapping_types(left_type, right_type)): + is_overlapping_erased_types(left_type, right_type)): if node.operators == ['in']: return {expr: remove_optional(left_type)}, {} if node.operators == ['not in']: @@ -3442,7 +3475,8 @@ def conditional_type_map(expr: Expression, and is_proper_subtype(current_type, proposed_type)): # Expression is always of one of the types in proposed_type_ranges return {}, None - elif not is_overlapping_types(current_type, proposed_type): + elif not is_overlapping_types(current_type, proposed_type, + prohibit_none_typevar_overlap=True): # Expression is never of any type in proposed_type_ranges return None, {} else: @@ -3658,37 +3692,155 @@ def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool: def is_unsafe_overlapping_overload_signatures(signature: CallableType, other: CallableType) -> bool: - """Check if two overloaded function signatures may be unsafely overlapping. + """Check if two overloaded signatures are unsafely overlapping, ignoring partial overlaps. - We consider two functions 's' and 't' to be unsafely overlapping both if + We consider two functions 's' and 't' to be unsafely overlapping if both of the following are true: 1. s's parameters are all more precise or partially overlapping with t's 2. s's return type is NOT a subtype of t's. + This function will perform a modified version of the above two checks: + we do not check for partial overlaps. This lets us vary our error messages + depending on the severity of the overlap. + + See 'is_unsafe_partially_overlapping_overload_signatures' for the full checks. + Assumes that 'signature' appears earlier in the list of overload alternatives then 'other' and that their argument counts are overlapping. """ - # TODO: Handle partially overlapping parameter types - # - # For example, the signatures "f(x: Union[A, B]) -> int" and "f(x: Union[B, C]) -> str" - # is unsafe: the parameter types are partially overlapping. - # - # To fix this, we need to either modify meet.is_overlapping_types or add a new - # function and use "is_more_precise(...) or is_partially_overlapping(...)" for the is_compat - # checks. - # - # (We already have a rudimentary implementation of 'is_partially_overlapping', but it only - # attempts to handle the obvious cases -- see its docstring for more info.) + # if "foo" in signature.name or "bar" in signature.name or "chain_call" in signature.name: + # print("in first") + + signature = detach_callable(signature) + other = detach_callable(other) + + return (is_callable_compatible(signature, other, + is_compat=is_more_precise_no_promote, + is_compat_return=lambda l, r: not is_subtype_no_promote(l, r), + ignore_return=False, + check_args_covariantly=True, + allow_partial_overlap=True) or + is_callable_compatible(other, signature, + is_compat=is_more_precise_no_promote, + is_compat_return=lambda l, r: not is_subtype_no_promote(r, l), + ignore_return=False, + check_args_covariantly=False, + allow_partial_overlap=True)) + + +def is_unsafe_partially_overlapping_overload_signatures(signature: CallableType, + other: CallableType) -> bool: + """Check if two overloaded signatures are unsafely overlapping, ignoring partial overlaps. + + We consider two functions 's' and 't' to be unsafely overlapping if both + of the following are true: + + 1. s's parameters are all more precise or partially overlapping with t's + 2. s's return type is NOT a subtype of t's. + + Assumes that 'signature' appears earlier in the list of overload + alternatives then 'other' and that their argument counts are overlapping. + """ + # if "foo" in signature.name or "bar" in signature.name or "chain_call" in signature.name: + # print("in second") def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool: - return is_more_precise(t, s) or is_partially_overlapping_types(t, s) + return is_more_precise_no_promote(t, s) or is_overlapping_types_no_promote(t, s) - return is_callable_compatible(signature, other, + # Try detaching callables from the containing class so that all TypeVars + # are treated as being free. + # + # This lets us identify cases where the two signatures use completely + # incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars + # test case. + signature = detach_callable(signature) + other = detach_callable(other) + + # Note: We repeat this check twice in both directions due to a slight + # asymmetry in 'is_callable_compatible'. When checking for partial overlaps, + # we attempt to unify 'signature' and 'other' both against each other. + # + # If 'signature' cannot be unified with 'other', we end early. However, + # if 'other' cannot be modified with 'signature', the function continues + # using the older version of 'other'. + # + # This discrepancy is unfortunately difficult to get rid of, so we repeat the + # checks twice in both directions for now. + return (is_callable_compatible(signature, other, is_compat=is_more_precise_or_partially_overlapping, - is_compat_return=lambda l, r: not is_subtype(l, r), + is_compat_return=lambda l, r: not is_subtype_no_promote(l, r), + ignore_return=False, check_args_covariantly=True, - allow_partial_overlap=True) + allow_partial_overlap=True) or + is_callable_compatible(other, signature, + is_compat=is_more_precise_or_partially_overlapping, + is_compat_return=lambda l, r: not is_subtype_no_promote(r, l), + ignore_return=False, + check_args_covariantly=False, + allow_partial_overlap=True)) + + +def detach_callable(typ: CallableType) -> CallableType: + """Ensures that the callable's type variables are 'detached' and independent of the context. + + A callable normally keeps track of the type variables it uses within its 'variables' field. + However, if the callable is from a method and that method is using a class type variable, + the callable will not keep track of that type variable since it belongs to the class. + + This function will traverse the callable and find all used type vars and add them to the + variables field if it isn't already present. + + The caller can then unify on all type variables whether or not the callable is originally + from a class or not.""" + type_list = typ.arg_types + [typ.ret_type] + # old_type_list = list(type_list) + + appear_map = {} # type: Dict[str, List[int]] + for i, inner_type in enumerate(type_list): + typevars_available = inner_type.accept(TypeVarExtractor()) + for var in typevars_available: + if var.fullname not in appear_map: + appear_map[var.fullname] = [] + appear_map[var.fullname].append(i) + + used_type_var_names = set() + for var_name, appearances in appear_map.items(): + used_type_var_names.add(var_name) + + all_type_vars = typ.accept(TypeVarExtractor()) + new_variables = [] + for var in set(all_type_vars): + if var.fullname not in used_type_var_names: + continue + new_variables.append(TypeVarDef( + name=var.name, + fullname=var.fullname, + id=var.id, + values=var.values, + upper_bound=var.upper_bound, + variance=var.variance, + )) + out = typ.copy_modified( + variables=new_variables, + arg_types=type_list[:-1], + ret_type=type_list[-1], + ) + return out + + +class TypeVarExtractor(TypeQuery[List[TypeVarType]]): + def __init__(self) -> None: + super().__init__(self._merge) + + def _merge(self, iter: Iterable[List[TypeVarType]]) -> List[TypeVarType]: + out = [] + for item in iter: + out.extend(item) + return out + + def visit_type_var(self, t: TypeVarType) -> List[TypeVarType]: + return [t] def overload_can_never_match(signature: CallableType, other: CallableType) -> bool: @@ -3704,69 +3856,6 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo ignore_return=True) -def is_unsafe_overlapping_operator_signatures(signature: Type, other: Type) -> bool: - """Check if two operator method signatures may be unsafely overlapping. - - Two signatures s and t are overlapping if both can be valid for the same - statically typed values and the return types are incompatible. - - Assume calls are first checked against 'signature', then against 'other'. - Thus if 'signature' is more general than 'other', there is no unsafe - overlapping. - - TODO: Clean up this function and make it not perform type erasure. - - Context: This function was previously used to make sure both overloaded - functions and operator methods were not unsafely overlapping. - - We changed the semantics for we should handle overloaded definitions, - but not operator functions. (We can't reuse the same semantics for both: - the overload semantics are too restrictive here). - - We should rewrite this method so that: - - 1. It uses many of the improvements made to overloads: in particular, - eliminating type erasure. - - 2. It contains just the logic necessary for operator methods. - """ - if isinstance(signature, CallableType): - if isinstance(other, CallableType): - # TODO varargs - # TODO keyword args - # TODO erasure - # TODO allow to vary covariantly - # Check if the argument counts are overlapping. - min_args = max(signature.min_args, other.min_args) - max_args = min(len(signature.arg_types), len(other.arg_types)) - if min_args > max_args: - # Argument counts are not overlapping. - return False - # Signatures are overlapping iff if they are overlapping for the - # smallest common argument count. - for i in range(min_args): - t1 = signature.arg_types[i] - t2 = other.arg_types[i] - if not is_overlapping_types(t1, t2): - return False - # All arguments types for the smallest common argument count are - # overlapping => the signature is overlapping. The overlapping is - # safe if the return types are identical. - if is_same_type(signature.ret_type, other.ret_type): - return False - # If the first signature has more general argument types, the - # latter will never be called - if is_more_general_arg_prefix(signature, other): - return False - # Special case: all args are subtypes, and returns are subtypes - if (all(is_proper_subtype(s, o) - for (s, o) in zip(signature.arg_types, other.arg_types)) and - is_subtype(signature.ret_type, other.ret_type)): - return False - return not is_more_precise_signature(signature, other) - return True - - def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: """Does t have wider arguments than s?""" # TODO should an overload with additional items be allowed to be more @@ -3784,20 +3873,6 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: return False -def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool: - """Are type variable definitions equivalent? - - Ignore ids, locations in source file and names. - """ - return ( - tv1.variance == tv2.variance - and is_same_types(tv1.values, tv2.values) - and ((tv1.upper_bound is None and tv2.upper_bound is None) - or (tv1.upper_bound is not None - and tv2.upper_bound is not None - and is_same_type(tv1.upper_bound, tv2.upper_bound)))) - - def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool: return is_callable_compatible(t, s, is_compat=is_same_type, @@ -3806,21 +3881,6 @@ def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool: ignore_pos_arg_names=True) -def is_more_precise_signature(t: CallableType, s: CallableType) -> bool: - """Is t more precise than s? - A signature t is more precise than s if all argument types and the return - type of t are more precise than the corresponding types in s. - Assume that the argument kinds and names are compatible, and that the - argument counts are overlapping. - """ - # TODO generic function types - # Only consider the common prefix of argument types. - for argt, args in zip(t.arg_types, s.arg_types): - if not is_more_precise(argt, args): - return False - return is_more_precise(t.ret_type, s.ret_type) - - def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, str]: """Determine if operator assignment on given value type is in-place, and the method name. @@ -3962,3 +4022,15 @@ def is_static(func: Union[FuncBase, Decorator]) -> bool: elif isinstance(func, FuncBase): return func.is_static assert False, "Unexpected func type: {}".format(type(func)) + + +def is_subtype_no_promote(left: Type, right: Type) -> bool: + return is_subtype(left, right, ignore_promotions=True) + + +def is_more_precise_no_promote(left: Type, right: Type) -> bool: + return is_more_precise(left, right, ignore_promotions=True) + + +def is_overlapping_types_no_promote(left: Type, right: Type) -> bool: + return is_overlapping_types(left, right, ignore_promotions=True) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 15d0a58f2b3b..024056c9e983 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -42,7 +42,9 @@ from mypy import join from mypy.meet import narrow_declared_type from mypy.maptype import map_instance_to_supertype -from mypy.subtypes import is_subtype, is_equivalent, find_member, non_method_protocol_members +from mypy.subtypes import ( + is_subtype, is_proper_subtype, is_equivalent, find_member, non_method_protocol_members, +) from mypy import applytype from mypy import erasetype from mypy.checkmember import analyze_member_access, type_object_type, bind_self @@ -1806,20 +1808,6 @@ def get_operator_method(self, op: str) -> str: else: return nodes.op_methods[op] - def _check_op_for_errors(self, method: str, base_type: Type, arg: Expression, - context: Context - ) -> Tuple[Tuple[Type, Type], MessageBuilder]: - """Type check a binary operation which maps to a method call. - - Return ((result type, inferred operator method type), error message). - """ - local_errors = self.msg.copy() - local_errors.disable_count = 0 - result = self.check_op_local(method, base_type, - arg, context, - local_errors) - return result, local_errors - def check_op_local(self, method: str, base_type: Type, arg: Expression, context: Context, local_errors: MessageBuilder) -> Tuple[Type, Type]: """Type check a binary operation which maps to a method call. @@ -1840,6 +1828,210 @@ def check_op_local(self, method: str, base_type: Type, arg: Expression, context, arg_messages=local_errors, callable_name=callable_name, object_type=object_type) + def check_op_reversible(self, + op_name: str, + left_type: Type, + left_expr: Expression, + right_type: Type, + right_expr: Expression, + context: Context) -> Tuple[Type, Type]: + # Note: this kludge exists mostly to maintain compatibility with + # existing error messages. Apparently, if the left-hand-side is a + # union and we have a type mismatch, we print out a special, + # abbreviated error message. (See messages.unsupported_operand_types). + unions_present = isinstance(left_type, UnionType) + + def make_local_errors() -> MessageBuilder: + """Creates a new MessageBuilder object.""" + local_errors = self.msg.clean_copy() + local_errors.disable_count = 0 + if unions_present: + local_errors.disable_type_names += 1 + return local_errors + + def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: + """Looks up the given operator and returns the corresponding type, + if it exists.""" + if not self.has_member(base_type, op_name): + return None + local_errors = make_local_errors() + member = analyze_member_access( + name=op_name, + typ=base_type, + node=context, + is_lvalue=False, + is_super=False, + is_operator=True, + builtin_type=self.named_type, + not_ready_callback=self.not_ready_callback, + msg=local_errors, + original_type=base_type, + chk=self.chk, + ) + if local_errors.is_errors(): + return None + else: + return member + + def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: + """Returns the name of the class that contains the actual definition of attr_name. + + So if class A defines foo and class B subclasses A, running + 'get_class_defined_in(B, "foo")` would return the full name of A. + + However, if B were to override and redefine foo, that method call would + return the full name of B instead. + + If the attr name is not present in the given class or its MRO, returns None. + """ + mro = typ.type.mro + if mro is None: + return None + + for cls in mro: + if cls.names.get(attr_name): + return cls.fullname() + return None + + # If either the LHS or the RHS are Any, we can't really concluding anything + # about the operation since the Any type may or may not define an + # __op__ or __rop__ method. So, we punt and return Any instead. + + if isinstance(left_type, AnyType): + any_type = AnyType(TypeOfAny.from_another_any, source_any=left_type) + return any_type, any_type + if isinstance(right_type, AnyType): + any_type = AnyType(TypeOfAny.from_another_any, source_any=right_type) + return any_type, any_type + + rev_op_name = self.get_reverse_op_method(op_name) + + # STEP 1: + # We start by getting the __op__ and __rop__ methods, if they exist. + + # Records the method type, the base type, and the argument. + variants_raw = [] # type: List[Tuple[Optional[Type], Type, Expression]] + + left_op = lookup_operator(op_name, left_type) + right_op = lookup_operator(rev_op_name, right_type) + + # STEP 2a: + # We figure out in which order Python will call the operator methods. As it + # turns out, it's not as simple as just trying to call __op__ first and + # __rop__ second. + + warn_about_uncalled_reverse_operator = False + bias_right = is_proper_subtype(right_type, left_type) + if op_name in nodes.op_methods_that_shortcut and is_same_type(left_type, right_type): + # When we do "A() + A()", for example, Python will only call the __add__ method, + # never the __radd__ method. + # + # This is the case even if the __add__ method is completely missing and the __radd__ + # method is defined. + # + # We report this error message here instead of in the definition checks + + variants_raw.append((left_op, left_type, right_expr)) + if right_op is not None: + warn_about_uncalled_reverse_operator = True + elif (is_subtype(right_type, left_type) + and isinstance(left_type, Instance) + and isinstance(right_type, Instance) + and lookup_definer(left_type, op_name) != lookup_definer(right_type, rev_op_name)): + # When we do "A() + B()" where B is a subclass of B, we'll actually try calling + # B's __radd__ method first, but ONLY if B explicitly defines or overrides the + # __radd__ method. + # + # This mechanism lets subclasses "refine" the expected outcome of the operation, even + # if they're located on the RHS. + + variants_raw.append((right_op, right_type, left_expr)) + variants_raw.append((left_op, left_type, right_expr)) + else: + # In all other cases, we do the usual thing and call __add__ first and + # __radd__ second when doing "A() + B()". + + variants_raw.append((left_op, left_type, right_expr)) + variants_raw.append((right_op, right_type, left_expr)) + + # STEP 2b: + # When running Python 2, we might also try calling the __cmp__ method. + + is_python_2 = self.chk.options.python_version[0] == 2 + if is_python_2 and op_name in nodes.ops_falling_back_to_cmp: + cmp_method = nodes.comparison_fallback_method + left_cmp_op = lookup_operator(cmp_method, left_type) + right_cmp_op = lookup_operator(cmp_method, right_type) + + if bias_right: + variants_raw.append((right_cmp_op, right_type, left_expr)) + variants_raw.append((left_cmp_op, left_type, right_expr)) + else: + variants_raw.append((left_cmp_op, left_type, right_expr)) + variants_raw.append((right_cmp_op, right_type, left_expr)) + + # STEP 3: + # We now filter out all non-existant operators. The 'variants' list contains + # all operator methods that are actually present, in the order that Python + # attempts to invoke them. + + variants = [(op, obj, arg) for (op, obj, arg) in variants_raw if op is not None] + + # STEP 4: + # We now try invoking each one. If an operation succeeds, end early and return + # the corresponding result. Otherwise, return the result and errors associated + # with the first entry. + + errors = [] + results = [] + for method, obj, arg in variants: + local_errors = make_local_errors() + + callable_name = None # type: Optional[str] + if isinstance(obj, Instance): + # TODO: Find out in which class the method was defined originally? + # TODO: Support non-Instance types. + callable_name = '{}.{}'.format(obj.type.fullname(), op_name) + + result = self.check_call(method, [arg], [nodes.ARG_POS], + context, arg_messages=local_errors, + callable_name=callable_name, object_type=obj) + if local_errors.is_errors(): + errors.append(local_errors) + results.append(result) + else: + return result + + # STEP 4b: + # Sometimes, the variants list is empty. In that case, we fall-back to attempting to + # call the __op__ method (even though it's missing). + + if len(errors) == 0: + local_errors = make_local_errors() + result = self.check_op_local(op_name, left_type, right_expr, context, local_errors) + + if local_errors.is_errors(): + errors.append(local_errors) + results.append(result) + else: + return result + + self.msg.add_errors(errors[0]) + if warn_about_uncalled_reverse_operator: + self.msg.reverse_operator_method_never_called( + nodes.op_methods_to_symbols[op_name], + op_name, + right_type, + rev_op_name, + context, + ) + if len(results) == 1: + return results[0] + else: + error_any = AnyType(TypeOfAny.from_error) + result = error_any, error_any + return result + def check_op(self, method: str, base_type: Type, arg: Expression, context: Context, allow_reverse: bool = False) -> Tuple[Type, Type]: @@ -1847,82 +2039,23 @@ def check_op(self, method: str, base_type: Type, arg: Expression, Return tuple (result type, inferred operator method type). """ - # Use a local error storage for errors related to invalid argument - # type (but NOT other errors). This error may need to be suppressed - # for operators which support __rX methods. - local_errors = self.msg.copy() - local_errors.disable_count = 0 - if not allow_reverse or self.has_member(base_type, method): - result = self.check_op_local(method, base_type, arg, context, - local_errors) - if allow_reverse: - arg_type = self.chk.type_map[arg] - if isinstance(arg_type, AnyType): - # If the right operand has type Any, we can't make any - # conjectures about the type of the result, since the - # operand could have a __r method that returns anything. - any_type = AnyType(TypeOfAny.from_another_any, source_any=arg_type) - result = any_type, result[1] - success = not local_errors.is_errors() - else: - error_any = AnyType(TypeOfAny.from_error) - result = error_any, error_any - success = False - if success or not allow_reverse or isinstance(base_type, AnyType): - # We were able to call the normal variant of the operator method, - # or there was some problem not related to argument type - # validity, or the operator has no __rX method. In any case, we - # don't need to consider the __rX method. - self.msg.add_errors(local_errors) - return result + + if allow_reverse: + return self.check_op_reversible( + op_name=method, + left_type=base_type, + left_expr=TempNode(base_type), + right_type=self.accept(arg), + right_expr=arg, + context=context) else: - # Calling the operator method was unsuccessful. Try the __rX - # method of the other operand instead. - rmethod = self.get_reverse_op_method(method) - arg_type = self.accept(arg) - base_arg_node = TempNode(base_type) - # In order to be consistent with showing an error about the lhs not matching if neither - # the lhs nor the rhs have a compatible signature, we keep track of the first error - # message generated when considering __rX methods and __cmp__ methods for Python 2. - first_error = None # type: Optional[Tuple[Tuple[Type, Type], MessageBuilder]] - if self.has_member(arg_type, rmethod): - result, local_errors = self._check_op_for_errors(rmethod, arg_type, - base_arg_node, context) - if not local_errors.is_errors(): - return result - first_error = first_error or (result, local_errors) - # If we've failed to find an __rX method and we're checking Python 2, check to see if - # there is a __cmp__ method on the lhs or on the rhs. - if (self.chk.options.python_version[0] == 2 and - method in nodes.ops_falling_back_to_cmp): - cmp_method = nodes.comparison_fallback_method - if self.has_member(base_type, cmp_method): - # First check the if the lhs has a __cmp__ method that works - result, local_errors = self._check_op_for_errors(cmp_method, base_type, - arg, context) - if not local_errors.is_errors(): - return result - first_error = first_error or (result, local_errors) - if self.has_member(arg_type, cmp_method): - # Failed to find a __cmp__ method on the lhs, check if - # the rhs as a __cmp__ method that can operate on lhs - result, local_errors = self._check_op_for_errors(cmp_method, arg_type, - base_arg_node, context) - if not local_errors.is_errors(): - return result - first_error = first_error or (result, local_errors) - if first_error: - # We found either a __rX method, a __cmp__ method on the base_type, or a __cmp__ - # method on the rhs and failed match. Return the error for the first of these to - # fail. - self.msg.add_errors(first_error[1]) - return first_error[0] - else: - # No __rX method or __cmp__. Do deferred type checking to - # produce error message that we may have missed previously. - # TODO Fix type checking an expression more than once. - return self.check_op_local(method, base_type, arg, context, - self.msg) + return self.check_op_local( + method=method, + base_type=base_type, + arg=arg, + context=context, + local_errors=self.msg, + ) def get_reverse_op_method(self, method: str) -> str: if method == '__div__' and self.chk.options.python_version[0] == 2: diff --git a/mypy/meet.py b/mypy/meet.py index abd0a97a3c6b..4ee31e39abe4 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -5,9 +5,14 @@ from mypy.types import ( Type, AnyType, TypeVisitor, UnboundType, NoneTyp, TypeVarType, Instance, CallableType, TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, - UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike + UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, ) -from mypy.subtypes import is_equivalent, is_subtype, is_protocol_implementation, is_proper_subtype +from mypy.subtypes import ( + is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, + is_proper_subtype, +) +from mypy.erasetype import erase_type +from mypy.maptype import map_instance_to_supertype from mypy import experiments @@ -32,7 +37,7 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: if isinstance(declared, UnionType): return UnionType.make_simplified_union([narrow_declared_type(x, narrowed) for x in declared.relevant_items()]) - elif not is_overlapping_types(declared, narrowed, use_promotions=True): + elif not is_overlapping_types(declared, narrowed): if experiments.STRICT_OPTIONAL: return UninhabitedType() else: @@ -49,155 +54,246 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: return narrowed -def is_partially_overlapping_types(t: Type, s: Type) -> bool: - """Returns 'true' if the two types are partially, but not completely, overlapping. - - NOTE: This function is only a partial implementation. - - It exists mostly so that overloads correctly handle partial - overlaps for the more obvious cases. - """ - # Are unions partially overlapping? - if isinstance(t, UnionType) and isinstance(s, UnionType): - t_set = set(t.items) - s_set = set(s.items) - num_same = len(t_set.intersection(s_set)) - num_diff = len(t_set.symmetric_difference(s_set)) - return num_same > 0 and num_diff > 0 - - # Are tuples partially overlapping? - tup_overlap = is_overlapping_tuples(t, s, use_promotions=True) - if tup_overlap is not None and tup_overlap: - return tup_overlap - - def is_object(t: Type) -> bool: - return isinstance(t, Instance) and t.type.fullname() == 'builtins.object' - - # Is either 't' or 's' an unrestricted TypeVar? - if isinstance(t, TypeVarType) and is_object(t.upper_bound) and len(t.values) == 0: - return True +def get_possible_variants(typ: Type) -> List[Type]: + """This function takes any "Union-like" type and returns a list of the available "options". - if isinstance(s, TypeVarType) and is_object(s.upper_bound) and len(s.values) == 0: - return True + Specifically, there are currently exactly three different types that can have + "variants" or are "union-like": - return False + - Unions + - TypeVars with value restrictions + - Overloads + This function will return a list of each "option" present in those types. -def is_overlapping_types(t: Type, s: Type, use_promotions: bool = False) -> bool: - """Can a value of type t be a value of type s, or vice versa? + If this function receives any other type, we return a list containing just that + original type. (E.g. pretend the type was contained within a singleton union). - Note that this effectively checks against erased types, since type - variables are erased at runtime and the overlapping check is based - on runtime behavior. The exception is protocol types, it is not safe, - but convenient and is an opt-in behavior. + The only exception is regular TypeVars: we return a list containing that TypeVar's + upper bound. - If use_promotions is True, also consider type promotions (int and - float would only be overlapping if it's True). + This function is useful primarily when checking to see if two types are overlapping: + the algorithm to check if two unions are overlapping is fundamentally the same as + the algorithm for checking if two overloads are overlapping. - This does not consider multiple inheritance. For example, A and B in - the following example are not considered overlapping, even though - via C they can be overlapping: + Normalizing both kinds of types in the same way lets us reuse the same algorithm + for both. + """ + if isinstance(typ, TypeVarType): + if len(typ.values) > 0: + return typ.values + else: + return [typ.upper_bound] + elif isinstance(typ, UnionType): + return typ.relevant_items() + elif isinstance(typ, Overloaded): + # Note: doing 'return typ.items()' makes mypy + # infer a too-specific return type of List[CallableType] + out = [] # type: List[Type] + out.extend(typ.items()) + return out + else: + return [typ] - class A: ... - class B: ... - class C(A, B): ... - The rationale is that this case is usually very unlikely as multiple - inheritance is rare. Also, we can't reliably determine whether - multiple inheritance actually occurs somewhere in a program, due to - stub files hiding implementation details, dynamic loading etc. +def is_overlapping_types(left: Type, + right: Type, + ignore_promotions: bool = False, + prohibit_none_typevar_overlap: bool = False) -> bool: + """Can a value of type 'left' also be of type 'right' or vice-versa? - TODO: Don't consider callables always overlapping. - TODO: Don't consider type variables with values always overlapping. + If 'ignore_promotions' is True, we ignore promotions while checking for overlaps. + If 'prohibit_none_typevar_overlap' is True, we disallow None from overlapping with + TypeVars (in both strict-optional and non-strict-optional mode). """ - # Any overlaps with everything - if isinstance(t, AnyType) or isinstance(s, AnyType): + + def _is_overlapping_types(left: Type, right: Type) -> bool: + '''Encode the kind of overlapping check to perform. + + This function mostly exists so we don't have to repeat keyword arguments everywhere.''' + return is_overlapping_types( + left, right, + ignore_promotions=ignore_promotions, + prohibit_none_typevar_overlap=prohibit_none_typevar_overlap) + + # We should never encounter these types, but if we do, we handle + # them in the same way we handle 'Any'. + illegal_types = (UnboundType, PartialType, ErasedType, DeletedType) + if isinstance(left, illegal_types) or isinstance(right, illegal_types): + # TODO: Replace this with an 'assert False' once we are confident we + # never accidentally generate these types. return True - # object overlaps with everything - if (isinstance(t, Instance) and t.type.fullname() == 'builtins.object' or - isinstance(s, Instance) and s.type.fullname() == 'builtins.object'): + + # 'Any' may or may not be overlapping with the other type + if isinstance(left, AnyType) or isinstance(right, AnyType): return True - if is_proper_subtype(t, s) or is_proper_subtype(s, t): + # We check for complete overlaps first as a general-purpose failsafe. + # If this check fails, we start checking to see if there exists a + # *partial* overlap between types. + # + # These checks will also handle the NoneTyp and UninhabitedType cases for us. + + if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions) + or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)): return True - # Since we are effectively working with the erased types, we only - # need to handle occurrences of TypeVarType at the top level. - if isinstance(t, TypeVarType): - t = t.erase_to_union_or_bound() - if isinstance(s, TypeVarType): - s = s.erase_to_union_or_bound() - if isinstance(t, TypedDictType): - t = t.as_anonymous().fallback - if isinstance(s, TypedDictType): - s = s.as_anonymous().fallback - - if isinstance(t, UnionType): - return any(is_overlapping_types(item, s) - for item in t.relevant_items()) - if isinstance(s, UnionType): - return any(is_overlapping_types(t, item) - for item in s.relevant_items()) - - # We must check for TupleTypes before Instances, since Tuple[A, ...] - # is an Instance - tup_overlap = is_overlapping_tuples(t, s, use_promotions) - if tup_overlap is not None: - return tup_overlap - - if isinstance(t, Instance): - if isinstance(s, Instance): - # Consider two classes non-disjoint if one is included in the mro - # of another. - if use_promotions: - # Consider cases like int vs float to be overlapping where - # there is only a type promotion relationship but not proper - # subclassing. - if t.type._promote and is_overlapping_types(t.type._promote, s): - return True - if s.type._promote and is_overlapping_types(s.type._promote, t): + + # See the docstring for 'get_possible_variants' for more info on what the + # following lines are doing. + + left_possible = get_possible_variants(left) + right_possible = get_possible_variants(right) + + # We start by checking multi-variant types like Unions first. We also perform + # the same logic if either type happens to be a TypeVar. + # + # Handling the TypeVars now lets us simulate having them bind to the corresponding + # type -- if we deferred these checks, the "return-early" logic of the other + # checks will prevent us from detecting certain overlaps. + # + # If both types are singleton variants (and are not TypeVars), we've hit the base case: + # we skip these checks to avoid infinitely recursing. + + def is_none_typevar_overlap(t1: Type, t2: Type) -> bool: + return isinstance(t1, NoneTyp) and isinstance(t2, TypeVarType) + + if prohibit_none_typevar_overlap: + if is_none_typevar_overlap(left, right) or is_none_typevar_overlap(right, left): + return False + + if (len(left_possible) > 1 or len(right_possible) > 1 + or isinstance(left, TypeVarType) or isinstance(right, TypeVarType)): + for l in left_possible: + for r in right_possible: + if _is_overlapping_types(l, r): return True - if t.type in s.type.mro or s.type in t.type.mro: - return True - if t.type.is_protocol and is_protocol_implementation(s, t): - return True - if s.type.is_protocol and is_protocol_implementation(t, s): - return True + return False + + # Now that we've finished handling TypeVars, we're free to end early + # if one one of the types is None and we're running in strict-optional + # mode. (We must perform this check after the TypeVar checks because + # a TypeVar could be bound to None, for example.) + + if experiments.STRICT_OPTIONAL: + if isinstance(left, NoneTyp) != isinstance(right, NoneTyp): return False - if isinstance(t, TypeType) and isinstance(s, TypeType): - # If both types are TypeType, compare their inner types. - return is_overlapping_types(t.item, s.item, use_promotions) - elif isinstance(t, TypeType) or isinstance(s, TypeType): - # If exactly only one of t or s is a TypeType, check if one of them - # is an `object` or a `type` and otherwise assume no overlap. - one = t if isinstance(t, TypeType) else s - other = s if isinstance(t, TypeType) else t - if isinstance(other, Instance): - return other.type.fullname() in {'builtins.object', 'builtins.type'} + + # Next, we handle single-variant types that may be inherently partially overlapping: + # + # - TypedDicts + # - Tuples + # + # If we cannot identify a partial overlap and end early, we degrade these two types + # into their 'Instance' fallbacks. + + if isinstance(left, TypedDictType) and isinstance(right, TypedDictType): + return are_typed_dicts_overlapping(left, right, ignore_promotions=ignore_promotions) + elif isinstance(left, TypedDictType): + left = left.fallback + elif isinstance(right, TypedDictType): + right = right.fallback + + if is_tuple(left) and is_tuple(right): + return are_tuples_overlapping(left, right, ignore_promotions=ignore_promotions) + elif isinstance(left, TupleType): + left = left.fallback + elif isinstance(right, TupleType): + right = right.fallback + + # Next, we handle single-variant types that cannot be inherently partially overlapping, + # but do require custom logic to inspect. + # + # As before, we degrade into 'Instance' whenever possible. + + if isinstance(left, TypeType) and isinstance(right, TypeType): + # TODO: Can Callable[[...], T] and Type[T] be partially overlapping? + return _is_overlapping_types(left.item, right.item) + + if isinstance(left, CallableType) and isinstance(right, CallableType): + return is_callable_compatible(left, right, + is_compat=_is_overlapping_types, + ignore_pos_arg_names=True, + allow_partial_overlap=True) + elif isinstance(left, CallableType): + left = left.fallback + elif isinstance(right, CallableType): + right = right.fallback + + # Finally, we handle the case where left and right are instances. + + if isinstance(left, Instance) and isinstance(right, Instance): + if left.type.is_protocol and is_protocol_implementation(right, left): + return True + if right.type.is_protocol and is_protocol_implementation(left, right): + return True + + # Two unrelated types cannot be partially overlapping: they're disjoint. + # We don't need to handle promotions because they've already been handled + # by the calls to `is_subtype(...)` up above (and promotable types never + # have any generic arguments we need to recurse on). + if left.type.has_base(right.type.fullname()): + left = map_instance_to_supertype(left, right.type) + elif right.type.has_base(left.type.fullname()): + right = map_instance_to_supertype(right, left.type) else: - return isinstance(other, CallableType) and is_subtype(other, one) - if experiments.STRICT_OPTIONAL: - if isinstance(t, NoneTyp) != isinstance(s, NoneTyp): - # NoneTyp does not overlap with other non-Union types under strict Optional checking return False - # We conservatively assume that non-instance, non-union, non-TupleType and non-TypeType types - # can overlap any other types. + + if len(left.args) == len(right.args): + for left_arg, right_arg in zip(left.args, right.args): + if _is_overlapping_types(left_arg, right_arg): + return True + + # We ought to have handled every case by now: we conclude the + # two types are not overlapping, either completely or partially. + + return False + + +def is_overlapping_erased_types(left: Type, right: Type, *, + ignore_promotions: bool = False) -> bool: + """The same as 'is_overlapping_erased_types', except the types are erased first.""" + return is_overlapping_types(erase_type(left), erase_type(right), + ignore_promotions=ignore_promotions) + + +def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType, *, + ignore_promotions: bool = False) -> bool: + """Returns 'true' if left and right are overlapping TypeDictTypes.""" + # All required keys in left are present and overlapping with something in right + for key in left.required_keys: + if key not in right.items: + return False + if not is_overlapping_types(left.items[key], right.items[key], + ignore_promotions=ignore_promotions): + return False + + # Repeat check in the other direction + for key in right.required_keys: + if key not in left.items: + return False + if not is_overlapping_types(left.items[key], right.items[key], + ignore_promotions=ignore_promotions): + return False + + # The presence of any additional optional keys does not affect whether the two + # TypedDicts are partially overlapping: the dicts would be overlapping if the + # keys happened to be missing. return True -def is_overlapping_tuples(t: Type, s: Type, use_promotions: bool) -> Optional[bool]: - """Part of is_overlapping_types(), for tuples only""" - t = adjust_tuple(t, s) or t - s = adjust_tuple(s, t) or s - if isinstance(t, TupleType) or isinstance(s, TupleType): - if isinstance(t, TupleType) and isinstance(s, TupleType): - if t.length() == s.length(): - if all(is_overlapping_types(ti, si, use_promotions) - for ti, si in zip(t.items, s.items)): - return True - # TupleType and non-tuples do not overlap +def are_tuples_overlapping(left: Type, right: Type, *, + ignore_promotions: bool = False) -> bool: + """Returns true if left and right are overlapping tuples. + + Precondition: is_tuple(left) and is_tuple(right) are both true.""" + left = adjust_tuple(left, right) or left + right = adjust_tuple(right, left) or right + assert isinstance(left, TupleType) + assert isinstance(right, TupleType) + if len(left.items) != len(right.items): return False - # No tuples are involved here - return None + return all(is_overlapping_types(l, r, ignore_promotions=ignore_promotions) + for l, r in zip(left.items, right.items)) def adjust_tuple(left: Type, r: Type) -> Optional[TupleType]: @@ -208,6 +304,11 @@ def adjust_tuple(left: Type, r: Type) -> Optional[TupleType]: return None +def is_tuple(typ: Type) -> bool: + return (isinstance(typ, TupleType) + or (isinstance(typ, Instance) and typ.type.fullname() == 'builtins.tuple')) + + class TypeMeetVisitor(TypeVisitor[Type]): def __init__(self, s: Type) -> None: self.s = s diff --git a/mypy/messages.py b/mypy/messages.py index 27dada9477c9..a33debee153a 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -973,6 +973,12 @@ def overloaded_signatures_overlap(self, index1: int, index2: int, context: Conte self.fail('Overloaded function signatures {} and {} overlap with ' 'incompatible return types'.format(index1, index2), context) + def overloaded_signatures_partial_overlap(self, index1: int, index2: int, + context: Context) -> None: + self.fail('Overloaded function signatures {} and {} '.format(index1, index2) + + 'are partially overlapping: the two signatures may return ' + + 'incompatible types given certain calls', context) + def overloaded_signature_will_never_match(self, index1: int, index2: int, context: Context) -> None: self.fail( @@ -994,6 +1000,22 @@ def overloaded_signatures_ret_specific(self, index: int, context: Context) -> No self.fail('Overloaded function implementation cannot produce return type ' 'of signature {}'.format(index), context) + def reverse_operator_method_never_called(self, + op: str, + forward_method: str, + reverse_type: Type, + reverse_method: str, + context: Context) -> None: + msg = "{rfunc} will not be called when running '{cls} {op} {cls}': must define {ffunc}" + self.note( + msg.format( + op=op, + ffunc=forward_method, + rfunc=reverse_method, + cls=self.format_bare(reverse_type), + ), + context=context) + def operator_method_signatures_overlap( self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type, forward_method: str, context: Context) -> None: diff --git a/mypy/nodes.py b/mypy/nodes.py index 9f8d10666251..a0acd73714b4 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1471,6 +1471,8 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: 'in': '__contains__', } # type: Dict[str, str] +op_methods_to_symbols = {v: k for (k, v) in op_methods.items()} + comparison_fallback_method = '__cmp__' ops_falling_back_to_cmp = {'__ne__', '__eq__', '__lt__', '__le__', @@ -1506,6 +1508,26 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: '__le__': '__ge__', } +# Suppose we have some class A. When we do A() + A(), Python will only check +# the output of A().__add__(A()) and skip calling the __radd__ method entirely. +# This shortcut is used only for the following methods: +op_methods_that_shortcut = { + '__add__', + '__sub__', + '__mul__', + '__truediv__', + '__mod__', + '__divmod__', + '__floordiv__', + '__pow__', + '__matmul__', + '__and__', + '__or__', + '__xor__', + '__lshift__', + '__rshift__', +} + normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) reverse_op_method_set = set(reverse_op_methods.values()) diff --git a/mypy/sametypes.py b/mypy/sametypes.py index b382c632ffe3..ef053a5b4b19 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -98,7 +98,8 @@ def visit_callable_type(self, left: CallableType) -> bool: def visit_tuple_type(self, left: TupleType) -> bool: if isinstance(self.right, TupleType): - return is_same_types(left.items, self.right.items) + return (is_same_type(left.fallback, self.right.fallback) + and is_same_types(left.items, self.right.items)) else: return False diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7c57579dee98..5c10a49d467c 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -20,7 +20,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance from mypy.sametypes import is_same_type -from mypy.typestate import TypeState +from mypy.typestate import TypeState, SubtypeKind from mypy import experiments @@ -46,7 +46,8 @@ def check_type_parameter(lefta: Type, righta: Type, variance: int) -> bool: def is_subtype(left: Type, right: Type, type_parameter_checker: TypeParameterChecker = check_type_parameter, *, ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False) -> bool: + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> bool: """Is 'left' subtype of 'right'? Also consider Any to be a subtype of any type, and vice versa. This @@ -66,7 +67,9 @@ def is_subtype(left: Type, right: Type, # 'left' can be a subtype of the union 'right' is if it is a # subtype of one of the items making up the union. is_subtype_of_item = any(is_subtype(left, item, type_parameter_checker, - ignore_pos_arg_names=ignore_pos_arg_names) + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) for item in right.items) # However, if 'left' is a type variable T, T might also have # an upper bound which is itself a union. This case will be @@ -81,7 +84,8 @@ def is_subtype(left: Type, right: Type, # otherwise, fall through return left.accept(SubtypeVisitor(right, type_parameter_checker, ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance)) + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions)) def is_subtype_ignoring_tvars(left: Type, right: Type) -> bool: @@ -106,11 +110,43 @@ class SubtypeVisitor(TypeVisitor[bool]): def __init__(self, right: Type, type_parameter_checker: TypeParameterChecker, *, ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False) -> None: + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> None: self.right = right self.check_type_parameter = type_parameter_checker self.ignore_pos_arg_names = ignore_pos_arg_names self.ignore_declared_variance = ignore_declared_variance + self.ignore_promotions = ignore_promotions + self._subtype_kind = SubtypeVisitor.build_subtype_kind( + type_parameter_checker=type_parameter_checker, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) + + @staticmethod + def build_subtype_kind(*, + type_parameter_checker: TypeParameterChecker = check_type_parameter, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> SubtypeKind: + return hash(('subtype', + type_parameter_checker, + ignore_pos_arg_names, + ignore_declared_variance, + ignore_promotions)) + + def _lookup_cache(self, left: Instance, right: Instance) -> bool: + return TypeState.is_cached_subtype_check(self._subtype_kind, left, right) + + def _record_cache(self, left: Instance, right: Instance) -> None: + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) + + def _is_subtype(self, left: Type, right: Type) -> bool: + return is_subtype(left, right, + type_parameter_checker=self.check_type_parameter, + ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_declared_variance=self.ignore_declared_variance, + ignore_promotions=self.ignore_promotions) # visit_x(left) means: is left (which is an instance of X) a subtype of # right? @@ -150,19 +186,17 @@ def visit_instance(self, left: Instance) -> bool: return True right = self.right if isinstance(right, TupleType) and right.fallback.type.is_enum: - return is_subtype(left, right.fallback) + return self._is_subtype(left, right.fallback) if isinstance(right, Instance): - if TypeState.is_cached_subtype_check(left, right): + if self._lookup_cache(left, right): return True # NOTE: left.type.mro may be None in quick mode if there # was an error somewhere. - if left.type.mro is not None: + if not self.ignore_promotions and left.type.mro is not None: for base in left.type.mro: # TODO: Also pass recursively ignore_declared_variance - if base._promote and is_subtype( - base._promote, self.right, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names): - TypeState.record_subtype_cache_entry(left, right) + if base._promote and self._is_subtype(base._promote, self.right): + self._record_cache(left, right) return True rname = right.type.fullname() # Always try a nominal check if possible, @@ -175,7 +209,7 @@ def visit_instance(self, left: Instance) -> bool: for lefta, righta, tvar in zip(t.args, right.args, right.type.defn.type_vars)) if nominal: - TypeState.record_subtype_cache_entry(left, right) + self._record_cache(left, right) return nominal if right.type.is_protocol and is_protocol_implementation(left, right): return True @@ -185,7 +219,7 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(item, TupleType): item = item.fallback if is_named_instance(left, 'builtins.type'): - return is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) + return self._is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) if left.type.is_metaclass(): if isinstance(item, AnyType): return True @@ -195,7 +229,7 @@ def visit_instance(self, left: Instance) -> bool: # Special case: Instance can be a subtype of Callable. call = find_member('__call__', left, left) if call: - return is_subtype(call, right) + return self._is_subtype(call, right) return False else: return False @@ -204,27 +238,24 @@ def visit_type_var(self, left: TypeVarType) -> bool: right = self.right if isinstance(right, TypeVarType) and left.id == right.id: return True - if left.values and is_subtype(UnionType.make_simplified_union(left.values), right): + if left.values and self._is_subtype(UnionType.make_simplified_union(left.values), right): return True - return is_subtype(left.upper_bound, self.right) + return self._is_subtype(left.upper_bound, self.right) def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): return is_callable_compatible( left, right, - is_compat=is_subtype, + is_compat=self._is_subtype, ignore_pos_arg_names=self.ignore_pos_arg_names) elif isinstance(right, Overloaded): - return all(is_subtype(left, item, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names) - for item in right.items()) + return all(self._is_subtype(left, item) for item in right.items()) elif isinstance(right, Instance): - return is_subtype(left.fallback, right, - ignore_pos_arg_names=self.ignore_pos_arg_names) + return is_subtype(left.fallback, right) elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return left.is_type_obj() and is_subtype(left.ret_type, right.item) + return left.is_type_obj() and self._is_subtype(left.ret_type, right.item) else: return False @@ -242,17 +273,17 @@ def visit_tuple_type(self, left: TupleType) -> bool: iter_type = right.args[0] else: iter_type = AnyType(TypeOfAny.special_form) - return all(is_subtype(li, iter_type) for li in left.items) - elif is_subtype(left.fallback, right, self.check_type_parameter): + return all(self._is_subtype(li, iter_type) for li in left.items) + elif self._is_subtype(left.fallback, right): return True return False elif isinstance(right, TupleType): if len(left.items) != len(right.items): return False for l, r in zip(left.items, right.items): - if not is_subtype(l, r, self.check_type_parameter): + if not self._is_subtype(l, r): return False - if not is_subtype(left.fallback, right.fallback, self.check_type_parameter): + if not self._is_subtype(left.fallback, right.fallback): return False return True else: @@ -261,7 +292,7 @@ def visit_tuple_type(self, left: TupleType) -> bool: def visit_typeddict_type(self, left: TypedDictType) -> bool: right = self.right if isinstance(right, Instance): - return is_subtype(left.fallback, right, self.check_type_parameter) + return self._is_subtype(left.fallback, right) elif isinstance(right, TypedDictType): if not left.names_are_wider_than(right): return False @@ -287,11 +318,10 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: def visit_overloaded(self, left: Overloaded) -> bool: right = self.right if isinstance(right, Instance): - return is_subtype(left.fallback, right) + return self._is_subtype(left.fallback, right) elif isinstance(right, CallableType): for item in left.items(): - if is_subtype(item, right, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names): + if self._is_subtype(item, right): return True return False elif isinstance(right, Overloaded): @@ -304,8 +334,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: found_match = False for left_index, left_item in enumerate(left.items()): - subtype_match = is_subtype(left_item, right_item, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names) + subtype_match = self._is_subtype(left_item, right_item)\ # Order matters: we need to make sure that the index of # this item is at least the index of the previous one. @@ -320,10 +349,10 @@ def visit_overloaded(self, left: Overloaded) -> bool: # If this one overlaps with the supertype in any way, but it wasn't # an exact match, then it's a potential error. if (is_callable_compatible(left_item, right_item, - is_compat=is_subtype, ignore_return=True, + is_compat=self._is_subtype, ignore_return=True, ignore_pos_arg_names=self.ignore_pos_arg_names) or is_callable_compatible(right_item, left_item, - is_compat=is_subtype, ignore_return=True, + is_compat=self._is_subtype, ignore_return=True, ignore_pos_arg_names=self.ignore_pos_arg_names)): # If this is an overload that's already been matched, there's no # problem. @@ -344,13 +373,12 @@ def visit_overloaded(self, left: Overloaded) -> bool: # All the items must have the same type object status, so # it's sufficient to query only (any) one of them. # This is unsound, we don't check all the __init__ signatures. - return left.is_type_obj() and is_subtype(left.items()[0], right) + return left.is_type_obj() and self._is_subtype(left.items()[0], right) else: return False def visit_union_type(self, left: UnionType) -> bool: - return all(is_subtype(item, self.right, self.check_type_parameter) - for item in left.items) + return all(self._is_subtype(item, self.right) for item in left.items) def visit_partial_type(self, left: PartialType) -> bool: # This is indeterminate as we don't really know the complete type yet. @@ -359,10 +387,10 @@ def visit_partial_type(self, left: PartialType) -> bool: def visit_type_type(self, left: TypeType) -> bool: right = self.right if isinstance(right, TypeType): - return is_subtype(left.item, right.item) + return self._is_subtype(left.item, right.item) if isinstance(right, CallableType): # This is unsound, we don't check the __init__ signature. - return is_subtype(left.item, right.ret_type) + return self._is_subtype(left.item, right.ret_type) if isinstance(right, Instance): if right.type.fullname() in ['builtins.object', 'builtins.type']: return True @@ -371,7 +399,7 @@ def visit_type_type(self, left: TypeType) -> bool: item = item.upper_bound if isinstance(item, Instance): metaclass = item.type.metaclass_type - return metaclass is not None and is_subtype(metaclass, right) + return metaclass is not None and self._is_subtype(metaclass, right) return False @@ -426,6 +454,8 @@ def f(self) -> A: ... return False if not proper_subtype: # Nominal check currently ignores arg names + # NOTE: If we ever change this, be sure to also change the call to + # SubtypeVisitor.build_subtype_kind(...) down below. is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True) else: is_compat = is_proper_subtype(subtype, supertype) @@ -447,10 +477,13 @@ def f(self) -> A: ... # This rule is copied from nominal check in checker.py if IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags: return False - if proper_subtype: - TypeState.record_proper_subtype_cache_entry(left, right) + + if not proper_subtype: + # Nominal check currently ignores arg names + subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=True) else: - TypeState.record_subtype_cache_entry(left, right) + subtype_kind = ProperSubtypeVisitor.build_subtype_kind() + TypeState.record_subtype_cache_entry(subtype_kind, left, right) return True @@ -926,6 +959,7 @@ def unify_generic_callable(type: CallableType, target: CallableType, c = mypy.constraints.infer_constraints( type.ret_type, target.ret_type, return_constraint_direction) constraints.extend(c) + type_var_ids = [tvar.id for tvar in type.variables] inferred_vars = mypy.solve.solve_constraints(type_var_ids, constraints) if None in inferred_vars: @@ -964,21 +998,38 @@ def restrict_subtype_away(t: Type, s: Type) -> Type: return t -def is_proper_subtype(left: Type, right: Type) -> bool: +def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: """Is left a proper subtype of right? For proper subtypes, there's no need to rely on compatibility due to Any types. Every usable type is a proper subtype of itself. """ if isinstance(right, UnionType) and not isinstance(left, UnionType): - return any([is_proper_subtype(left, item) + return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions) for item in right.items]) - return left.accept(ProperSubtypeVisitor(right)) + return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions)) class ProperSubtypeVisitor(TypeVisitor[bool]): - def __init__(self, right: Type) -> None: + def __init__(self, right: Type, *, ignore_promotions: bool = False) -> None: self.right = right + self.ignore_promotions = ignore_promotions + self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind( + ignore_promotions=ignore_promotions, + ) + + @staticmethod + def build_subtype_kind(*, ignore_promotions: bool = False) -> SubtypeKind: + return hash(('subtype_proper', ignore_promotions)) + + def _lookup_cache(self, left: Instance, right: Instance) -> bool: + return TypeState.is_cached_subtype_check(self._subtype_kind, left, right) + + def _record_cache(self, left: Instance, right: Instance) -> None: + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) + + def _is_proper_subtype(self, left: Type, right: Type) -> bool: + return is_proper_subtype(left, right, ignore_promotions=self.ignore_promotions) def visit_unbound_type(self, left: UnboundType) -> bool: # This can be called if there is a bad type annotation. The result probably @@ -1009,19 +1060,20 @@ def visit_deleted_type(self, left: DeletedType) -> bool: def visit_instance(self, left: Instance) -> bool: right = self.right if isinstance(right, Instance): - if TypeState.is_cached_proper_subtype_check(left, right): + if self._lookup_cache(left, right): return True - for base in left.type.mro: - if base._promote and is_proper_subtype(base._promote, right): - TypeState.record_proper_subtype_cache_entry(left, right) - return True + if not self.ignore_promotions: + for base in left.type.mro: + if base._promote and self._is_proper_subtype(base._promote, right): + self._record_cache(left, right) + return True if left.type.has_base(right.type.fullname()): def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: if variance == COVARIANT: - return is_proper_subtype(leftarg, rightarg) + return self._is_proper_subtype(leftarg, rightarg) elif variance == CONTRAVARIANT: - return is_proper_subtype(rightarg, leftarg) + return self._is_proper_subtype(rightarg, leftarg) else: return sametypes.is_same_type(leftarg, rightarg) # Map left type to corresponding right instances. @@ -1030,7 +1082,7 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in zip(left.args, right.args, right.type.defn.type_vars)) if nominal: - TypeState.record_proper_subtype_cache_entry(left, right) + self._record_cache(left, right) return nominal if (right.type.is_protocol and is_protocol_implementation(left, right, proper_subtype=True)): @@ -1039,29 +1091,30 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: if isinstance(right, CallableType): call = find_member('__call__', left, left) if call: - return is_proper_subtype(call, right) + return self._is_proper_subtype(call, right) return False return False def visit_type_var(self, left: TypeVarType) -> bool: if isinstance(self.right, TypeVarType) and left.id == self.right.id: return True - if left.values and is_subtype(UnionType.make_simplified_union(left.values), self.right): + if left.values and is_subtype(UnionType.make_simplified_union(left.values), self.right, + ignore_promotions=self.ignore_promotions): return True - return is_proper_subtype(left.upper_bound, self.right) + return self._is_proper_subtype(left.upper_bound, self.right) def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): - return is_callable_compatible(left, right, is_compat=is_proper_subtype) + return is_callable_compatible(left, right, is_compat=self._is_proper_subtype) elif isinstance(right, Overloaded): - return all(is_proper_subtype(left, item) + return all(self._is_proper_subtype(left, item) for item in right.items()) elif isinstance(right, Instance): - return is_proper_subtype(left.fallback, right) + return self._is_proper_subtype(left.fallback, right) elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return left.is_type_obj() and is_proper_subtype(left.ret_type, right.item) + return left.is_type_obj() and self._is_proper_subtype(left.ret_type, right.item) return False def visit_tuple_type(self, left: TupleType) -> bool: @@ -1079,15 +1132,15 @@ def visit_tuple_type(self, left: TupleType) -> bool: # TODO: We shouldn't need this special case. This is currently needed # for isinstance(x, tuple), though it's unclear why. return True - return all(is_proper_subtype(li, iter_type) for li in left.items) - return is_proper_subtype(left.fallback, right) + return all(self._is_proper_subtype(li, iter_type) for li in left.items) + return self._is_proper_subtype(left.fallback, right) elif isinstance(right, TupleType): if len(left.items) != len(right.items): return False for l, r in zip(left.items, right.items): - if not is_proper_subtype(l, r): + if not self._is_proper_subtype(l, r): return False - return is_proper_subtype(left.fallback, right.fallback) + return self._is_proper_subtype(left.fallback, right.fallback) return False def visit_typeddict_type(self, left: TypedDictType) -> bool: @@ -1100,14 +1153,14 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: if name not in left.items: return False return True - return is_proper_subtype(left.fallback, right) + return self._is_proper_subtype(left.fallback, right) def visit_overloaded(self, left: Overloaded) -> bool: # TODO: What's the right thing to do here? return False def visit_union_type(self, left: UnionType) -> bool: - return all([is_proper_subtype(item, self.right) for item in left.items]) + return all([self._is_proper_subtype(item, self.right) for item in left.items]) def visit_partial_type(self, left: PartialType) -> bool: # TODO: What's the right thing to do here? @@ -1118,10 +1171,10 @@ def visit_type_type(self, left: TypeType) -> bool: right = self.right if isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return is_proper_subtype(left.item, right.item) + return self._is_proper_subtype(left.item, right.item) if isinstance(right, CallableType): # This is also unsound because of __init__. - return right.is_type_obj() and is_proper_subtype(left.item, right.ret_type) + return right.is_type_obj() and self._is_proper_subtype(left.item, right.ret_type) if isinstance(right, Instance): if right.type.fullname() == 'builtins.type': # TODO: Strictly speaking, the type builtins.type is considered equivalent to @@ -1134,7 +1187,7 @@ def visit_type_type(self, left: TypeType) -> bool: return False -def is_more_precise(left: Type, right: Type) -> bool: +def is_more_precise(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: """Check if left is a more precise type than right. A left is a proper subtype of right, left is also more precise than @@ -1144,4 +1197,4 @@ def is_more_precise(left: Type, right: Type) -> bool: # TODO Should List[int] be more precise than List[Any]? if isinstance(right, AnyType): return True - return is_proper_subtype(left, right) + return is_proper_subtype(left, right, ignore_promotions=ignore_promotions) diff --git a/mypy/types.py b/mypy/types.py index 54d7d9a9e40b..e3fa4e9d0e1c 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1193,6 +1193,9 @@ def deserialize(cls, data: JsonDict) -> 'TypedDictType': set(data['required_keys']), Instance.deserialize(data['fallback'])) + def has_optional_keys(self) -> bool: + return any(key not in self.required_keys for key in self.items) + def is_anonymous(self) -> bool: return self.fallback.type.fullname() == 'typing.Mapping' diff --git a/mypy/typestate.py b/mypy/typestate.py index 337aac21d714..27a3c98a1ccc 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -3,7 +3,7 @@ and potentially other mutable TypeInfo state. This module contains mutable global state. """ -from typing import Dict, Set, Tuple, Optional +from typing import Any, Dict, Set, Tuple, Optional MYPY = False if MYPY: @@ -12,6 +12,17 @@ from mypy.types import Instance from mypy.server.trigger import make_trigger +# Represents that the 'left' instance is a subtype of the 'right' instance +SubtypeRelationship = Tuple[Instance, Instance] + +# A hash encoding the specific conditions under which we performed the subtype check. +# (e.g. did we want a proper subtype? A regular subtype while ignoring variance?) +SubtypeKind = int + +# A cache that keeps track of whether the given TypeInfo is a part of a particular +# subtype relationship +SubtypeCache = Dict[TypeInfo, Dict[SubtypeKind, Set[SubtypeRelationship]]] + class TypeState: """This class provides subtype caching to improve performance of subtype checks. @@ -23,13 +34,11 @@ class TypeState: The protocol dependencies however are only stored here, and shouldn't be deleted unless not needed any more (e.g. during daemon shutdown). """ - # 'caches' and 'caches_proper' are subtype caches, implemented as sets of pairs - # of (subtype, supertype), where supertypes are instances of given TypeInfo. + # '_subtype_caches' keeps track of (subtype, supertype) pairs where supertypes are + # instances of the given TypeInfo. The cache also keeps track of the specific + # *kind* of subtyping relationship, which we represent as an arbitrary hashable tuple. # We need the caches, since subtype checks for structural types are very slow. - # _subtype_caches_proper is for caching proper subtype checks (i.e. not assuming that - # Any is consistent with every type). - _subtype_caches = {} # type: ClassVar[Dict[TypeInfo, Set[Tuple[Instance, Instance]]]] - _subtype_caches_proper = {} # type: ClassVar[Dict[TypeInfo, Set[Tuple[Instance, Instance]]]] + _subtype_caches = {} # type: ClassVar[SubtypeCache] # This contains protocol dependencies generated after running a full build, # or after an update. These dependencies are special because: @@ -70,13 +79,11 @@ class TypeState: def reset_all_subtype_caches(cls) -> None: """Completely reset all known subtype caches.""" cls._subtype_caches = {} - cls._subtype_caches_proper = {} @classmethod def reset_subtype_caches_for(cls, info: TypeInfo) -> None: """Reset subtype caches (if any) for a given supertype TypeInfo.""" - cls._subtype_caches.setdefault(info, set()).clear() - cls._subtype_caches_proper.setdefault(info, set()).clear() + cls._subtype_caches.setdefault(info, dict()).clear() @classmethod def reset_all_subtype_caches_for(cls, info: TypeInfo) -> None: @@ -85,20 +92,15 @@ def reset_all_subtype_caches_for(cls, info: TypeInfo) -> None: cls.reset_subtype_caches_for(item) @classmethod - def is_cached_subtype_check(cls, left: Instance, right: Instance) -> bool: - return (left, right) in cls._subtype_caches.setdefault(right.type, set()) - - @classmethod - def is_cached_proper_subtype_check(cls, left: Instance, right: Instance) -> bool: - return (left, right) in cls._subtype_caches_proper.setdefault(right.type, set()) - - @classmethod - def record_subtype_cache_entry(cls, left: Instance, right: Instance) -> None: - cls._subtype_caches.setdefault(right.type, set()).add((left, right)) + def is_cached_subtype_check(cls, kind: SubtypeKind, left: Instance, right: Instance) -> bool: + subtype_kinds = cls._subtype_caches.setdefault(right.type, dict()) + return (left, right) in subtype_kinds.setdefault(kind, set()) @classmethod - def record_proper_subtype_cache_entry(cls, left: Instance, right: Instance) -> None: - cls._subtype_caches_proper.setdefault(right.type, set()).add((left, right)) + def record_subtype_cache_entry(cls, kind: SubtypeKind, + left: Instance, right: Instance) -> None: + subtype_kinds = cls._subtype_caches.setdefault(right.type, dict()) + subtype_kinds.setdefault(kind, set()).add((left, right)) @classmethod def reset_protocol_deps(cls) -> None: diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index 8efa81be346c..8ffb516a8a2e 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -661,16 +661,16 @@ reveal_type(D.__lt__) # E: Revealed type is 'def [AT] (self: AT`1, other: AT`1) A() < A() B() < B() -A() < B() # E: Unsupported operand types for > ("B" and "A") +A() < B() # E: Unsupported operand types for < ("A" and "B") C() > A() C() > B() C() > C() -C() > D() # E: Unsupported operand types for < ("D" and "C") +C() > D() # E: Unsupported operand types for > ("C" and "D") D() >= A() -D() >= B() # E: Unsupported operand types for <= ("B" and "D") -D() >= C() # E: Unsupported operand types for <= ("C" and "D") +D() >= B() # E: Unsupported operand types for >= ("D" and "B") +D() >= C() # E: Unsupported operand types for >= ("D" and "C") D() >= D() A() <= 1 # E: Unsupported operand types for <= ("A" and "int") diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index dd2120f191e9..55cc0a0e6c21 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -1599,6 +1599,33 @@ class A: class B(A): def __add__(self, x): pass +[case testOperatorMethodAgainstSameType] +class A: + def __add__(self, x: int) -> 'A': + if isinstance(x, int): + return A() + else: + return NotImplemented + + def __radd__(self, x: 'A') -> 'A': + if isinstance(x, A): + return A() + else: + return NotImplemented + +class B(A): pass + +# Note: This is a runtime error. If we run x.__add__(y) +# where x and y are *not* the same type, Python will not try +# calling __radd__. +A() + A() # E: Unsupported operand types for + ("A" and "A") \ + # N: __radd__ will not be called when running 'A + A': must define __add__ + +# Here, Python *will* call __radd__(...) +reveal_type(B() + A()) # E: Revealed type is '__main__.A' +reveal_type(A() + B()) # E: Revealed type is '__main__.A' +[builtins fixtures/isinstance.pyi] + [case testOperatorMethodOverrideWithIdenticalOverloadedType] from foo import * [file foo.pyi] @@ -1755,20 +1782,69 @@ class B: def __radd__(*self) -> int: pass def __rsub__(*self: 'B') -> int: pass -[case testReverseOperatorTypeVar] +[case testReverseOperatorTypeVar1] +from typing import TypeVar, Any +T = TypeVar("T", bound='Real') +class Real: + def __add__(self, other: Any) -> str: ... +class Fraction(Real): + def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping + +# Note: When doing A + B and if B is a subtype of A, we will always call B.__radd__(A) first +# and only try A.__add__(B) second if necessary. +reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' + +# Note: When doing A + A, we only ever call A.__add__(A), never A.__radd__(A). +reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' + +[case testReverseOperatorTypeVar2a] from typing import TypeVar T = TypeVar("T", bound='Real') class Real: - def __add__(self, other) -> str: ... + def __add__(self, other: Fraction) -> str: ... class Fraction(Real): def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping +reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' +reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' + + +[case testReverseOperatorTypeVar2b] +from typing import TypeVar +T = TypeVar("T", Real, Fraction) +class Real: + def __add__(self, other: Fraction) -> str: ... +class Fraction(Real): + def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "Real" are unsafely overlapping + +reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' +reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' + +[case testReverseOperatorTypeVar3] +from typing import TypeVar, Any +T = TypeVar("T", bound='Real') +class Real: + def __add__(self, other: FractionChild) -> str: ... +class Fraction(Real): + def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping +class FractionChild(Fraction): pass + +reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' +reveal_type(FractionChild() + Fraction()) # E: Revealed type is '__main__.FractionChild*' +reveal_type(FractionChild() + FractionChild()) # E: Revealed type is 'builtins.str' + +# Runtime error: we try calling __add__, it doesn't match, and we don't try __radd__ since +# the LHS and the RHS are not the same. +Fraction() + Fraction() # E: Unsupported operand types for + ("Fraction" and "Fraction") \ + # N: __radd__ will not be called when running 'Fraction + Fraction': must define __add__ + [case testReverseOperatorTypeType] from typing import TypeVar, Type class Real(type): - def __add__(self, other) -> str: ... + def __add__(self, other: FractionChild) -> str: ... class Fraction(Real): def __radd__(self, other: Type['A']) -> Real: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "Type[A]" are unsafely overlapping +class FractionChild(Fraction): pass class A(metaclass=Real): pass @@ -1811,7 +1887,7 @@ class B: @overload def __radd__(self, x: A) -> str: pass # Error class X: - def __add__(self, x): pass + def __add__(self, x: B) -> int: pass [out] tmp/foo.pyi:6: error: Signatures of "__radd__" of "B" and "__add__" of "X" are unsafely overlapping diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 93023dbb3ac4..fd2bc496deb1 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -537,9 +537,9 @@ class B: def __gt__(self, o: 'B') -> bool: pass [builtins fixtures/bool.pyi] [out] -main:3: error: Unsupported operand types for > ("A" and "A") -main:5: error: Unsupported operand types for > ("A" and "A") +main:3: error: Unsupported operand types for < ("A" and "A") main:5: error: Unsupported operand types for < ("A" and "A") +main:5: error: Unsupported operand types for > ("A" and "A") [case testChainedCompBoolRes] @@ -664,7 +664,7 @@ A() + cast(Any, 1) class C: def __gt__(self, x: 'A') -> object: pass class A: - def __lt__(self, x: C) -> int: pass + def __lt__(self, x: C) -> int: pass # E: Signatures of "__lt__" of "A" and "__gt__" of "C" are unsafely overlapping class B: def __gt__(self, x: A) -> str: pass s = None # type: str diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index 6d76dcdd4605..7157c0ef58a4 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -2005,3 +2005,81 @@ def f(x: Union[A, str]) -> None: if isinstance(x, A): x.method_only_in_a() [builtins fixtures/isinstance.pyi] + +[case testIsInstanceInitialNoneCheckSkipsImpossibleCasesNoStrictOptional] +# flags: --strict-optional +from typing import Optional, Union + +class A: pass + +def foo1(x: Union[A, str, None]) -> None: + if x is None: + reveal_type(x) # E: Revealed type is 'None' + elif isinstance(x, A): + reveal_type(x) # E: Revealed type is '__main__.A' + else: + reveal_type(x) # E: Revealed type is 'builtins.str' + +def foo2(x: Optional[str]) -> None: + if x is None: + reveal_type(x) # E: Revealed type is 'None' + elif isinstance(x, A): + reveal_type(x) + else: + reveal_type(x) # E: Revealed type is 'builtins.str' +[builtins fixtures/isinstance.pyi] + +[case testIsInstanceInitialNoneCheckSkipsImpossibleCasesInNoStrictOptional] +# flags: --no-strict-optional +from typing import Optional, Union + +class A: pass + +def foo1(x: Union[A, str, None]) -> None: + if x is None: + # Since None is a subtype of all types in no-strict-optional, + # we can't really narrow the type here + reveal_type(x) # E: Revealed type is 'Union[__main__.A, builtins.str, None]' + elif isinstance(x, A): + # Note that Union[None, A] == A in no-strict-optional + reveal_type(x) # E: Revealed type is '__main__.A' + else: + reveal_type(x) # E: Revealed type is 'builtins.str' + +def foo2(x: Optional[str]) -> None: + if x is None: + reveal_type(x) # E: Revealed type is 'Union[builtins.str, None]' + elif isinstance(x, A): + # Mypy should, however, be able to skip impossible cases + reveal_type(x) + else: + reveal_type(x) # E: Revealed type is 'Union[builtins.str, None]' +[builtins fixtures/isinstance.pyi] + +[case testNoneCheckDoesNotNarrowWhenUsingTypeVars] +# flags: --strict-optional +from typing import TypeVar + +T = TypeVar('T') + +def foo(x: T) -> T: + out = None + out = x + if out is None: + pass + return out +[builtins fixtures/isinstance.pyi] + +[case testNoneCheckDoesNotNarrowWhenUsingTypeVarsNoStrictOptional] +# flags: --no-strict-optional +from typing import TypeVar + +T = TypeVar('T') + +def foo(x: T) -> T: + out = None + out = x + if out is None: + pass + return out +[builtins fixtures/isinstance.pyi] diff --git a/test-data/unit/check-namedtuple.test b/test-data/unit/check-namedtuple.test index cf1c3a31d8de..5a1f5869b568 100644 --- a/test-data/unit/check-namedtuple.test +++ b/test-data/unit/check-namedtuple.test @@ -685,7 +685,7 @@ my_eval(A([B(1), B(2)])) # OK from typing import NamedTuple class Real(NamedTuple): - def __sub__(self, other) -> str: return "" + def __sub__(self, other: Real) -> str: return "" class Fraction(Real): def __rsub__(self, other: Real) -> Real: return other # E: Signatures of "__rsub__" of "Fraction" and "__sub__" of "Real" are unsafely overlapping diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 160a1a84732c..cedb26f58065 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -334,8 +334,8 @@ def bar(x: Union[T, C]) -> Union[T, int]: [builtins fixtures/isinstancelist.pyi] -[case testTypeCheckOverloadImplementationTypeVarDifferingUsage] -from typing import overload, Union, List, TypeVar +[case testTypeCheckOverloadImplementationTypeVarDifferingUsage1] +from typing import overload, Union, List, TypeVar, Generic T = TypeVar('T') @@ -348,6 +348,50 @@ def foo(t: Union[List[T], T]) -> T: return t[0] else: return t + +class Wrapper(Generic[T]): + @overload + def foo(self, t: List[T]) -> T: ... + @overload + def foo(self, t: T) -> T: ... + def foo(self, t: Union[List[T], T]) -> T: + if isinstance(t, list): + return t[0] + else: + return t +[builtins fixtures/isinstancelist.pyi] + +[case testTypeCheckOverloadImplementationTypeVarDifferingUsage2] +from typing import overload, Union, List, TypeVar, Generic + +T = TypeVar('T') + +# Note: this is unsafe when T = object +@overload +def foo(t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def foo(t: T, s: T) -> str: ... +def foo(t, s): pass + +# TODO: Why are we getting a different error message here? +# Shouldn't we be getting the same error message? +class Wrapper(Generic[T]): + @overload + def foo(self, t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def foo(self, t: T, s: T) -> str: ... + def foo(self, t, s): pass + +class Dummy(Generic[T]): pass + +# Same root issue: why does the additional constraint bound T <: T +# cause the constraint solver to not infer T = object like it did in the +# first example? +@overload +def bar(d: Dummy[T], t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def bar(d: Dummy[T], t: T, s: T) -> str: ... +def bar(d: Dummy[T], t, s): pass [builtins fixtures/isinstancelist.pyi] [case testTypeCheckOverloadedFunctionBody] @@ -1524,14 +1568,25 @@ reveal_type(f(z='', x=a, y=1)) # E: Revealed type is 'Any' [case testOverloadWithOverlappingItemsAndAnyArgument5] from typing import overload, Any, Union +class A: pass +class B(A): pass + @overload -def f(x: int) -> int: ... +def f(x: B) -> B: ... @overload -def f(x: Union[int, float]) -> float: ... +def f(x: Union[A, B]) -> A: ... def f(x): pass +# Note: overloads ignore promotions so we treat 'int' and 'float' as distinct types +@overload +def g(x: int) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def g(x: Union[int, float]) -> float: ... +def g(x): pass + a: Any reveal_type(f(a)) # E: Revealed type is 'Any' +reveal_type(g(a)) # E: Revealed type is 'Any' [case testOverloadWithOverlappingItemsAndAnyArgument6] from typing import overload, Any @@ -1817,7 +1872,7 @@ def r(x: Any) -> Any:... @overload def g(x: A) -> A: ... @overload -def g(x: Tuple[A1, ...]) -> B: ... # E: Overloaded function signatures 2 and 3 overlap with incompatible return types +def g(x: Tuple[A1, ...]) -> B: ... # E: Overloaded function signatures 2 and 3 are partially overlapping: the two signatures may return incompatible types given certain calls @overload def g(x: Tuple[A, A]) -> C: ... @overload @@ -2004,7 +2059,7 @@ def foo(x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and def foo(x: T, y: T) -> int: ... def foo(x): ... -# TODO: We should allow this; T can't be bound to two distinct types +# What if 'T' is 'object'? @overload def bar(x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types @overload @@ -2012,17 +2067,17 @@ def bar(x: T, y: T) -> int: ... def bar(x, y): ... class Wrapper(Generic[T]): - # TODO: This should be an error + # TODO: Why do these have different error messages? @overload - def foo(self, x: None, y: None) -> str: ... + def foo(self, x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload - def foo(self, x: T, y: None) -> str: ... + def foo(self, x: T, y: None) -> int: ... def foo(self, x): ... @overload - def bar(self, x: None, y: int) -> str: ... + def bar(self, x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload - def bar(self, x: T, y: T) -> str: ... + def bar(self, x: T, y: T) -> int: ... def bar(self, x, y): ... [case testOverloadFlagsPossibleMatches] @@ -2526,7 +2581,7 @@ class C: ... class D: ... @overload -def f(x: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def f(x: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload def f(x: Union[B, C]) -> str: ... def f(x): ... @@ -2534,26 +2589,366 @@ def f(x): ... @overload def g(x: Union[A, B]) -> int: ... @overload -def g(x: Union[C, D]) -> str: ... +def g(x: Union[B, C]) -> int: ... def g(x): ... @overload -def h(x: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def h(x: Union[A, B]) -> int: ... @overload -def h(x: Union[A, B, C]) -> str: ... +def h(x: Union[C, D]) -> str: ... def h(x): ... +@overload +def i(x: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def i(x: Union[A, B, C]) -> str: ... +def i(x): ... + +[case testOverloadWithPartiallyOverlappingUnionsNested] +from typing import overload, Union, List + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload +def f(x: List[Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: List[Union[B, C]]) -> str: ... +def f(x): ... + +@overload +def g(x: List[Union[A, B]]) -> int: ... +@overload +def g(x: List[Union[B, C]]) -> int: ... +def g(x): ... + +@overload +def h(x: List[Union[A, B]]) -> int: ... +@overload +def h(x: List[Union[C, D]]) -> str: ... +def h(x): ... + +@overload +def i(x: List[Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def i(x: List[Union[A, B, C]]) -> str: ... +def i(x): ... + +[builtins fixtures/list.pyi] + [case testOverloadPartialOverlapWithUnrestrictedTypeVar] from typing import TypeVar, overload T = TypeVar('T') @overload -def f(x: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def f(x: int) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload def f(x: T) -> T: ... def f(x): ... +@overload +def g(x: int) -> int: ... +@overload +def g(x: T) -> T: ... +def g(x): ... + +[case testOverloadPartialOverlapWithUnrestrictedTypeVarNested] +from typing import TypeVar, overload, List + +T = TypeVar('T') + +@overload +def f1(x: List[int]) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f1(x: List[T]) -> T: ... +def f1(x): ... + +@overload +def f2(x: List[int]) -> List[str]: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f2(x: List[T]) -> List[T]: ... +def f2(x): ... + +@overload +def g1(x: List[int]) -> int: ... +@overload +def g1(x: List[T]) -> T: ... +def g1(x): ... + +@overload +def g2(x: List[int]) -> List[int]: ... +@overload +def g2(x: List[T]) -> List[T]: ... +def g2(x): ... + +[builtins fixtures/list.pyi] + +[case testOverloadPartialOverlapWithUnrestrictedTypeVarInClass] +from typing import TypeVar, overload, Generic + +T = TypeVar('T') + +class Wrapper(Generic[T]): + @overload + def f(self, x: int) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def f(self, x: T) -> T: ... + def f(self, x): ... + + # TODO: This shouldn't trigger an error message. + # Related to testTypeCheckOverloadImplementationTypeVarDifferingUsage2? + @overload + def g(self, x: int) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def g(self, x: T) -> T: ... + def g(self, x): ... + +[case testOverloadPartialOverlapWithUnrestrictedTypeVarInClassNested] +from typing import TypeVar, overload, Generic, List + +T = TypeVar('T') + +class Wrapper(Generic[T]): + @overload + def f1(self, x: List[int]) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def f1(self, x: List[T]) -> T: ... + def f1(self, x): ... + + @overload + def f2(self, x: List[int]) -> List[str]: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def f2(self, x: List[T]) -> List[T]: ... + def f2(self, x): ... + + # TODO: This shouldn't trigger an error message. + # Related to testTypeCheckOverloadImplementationTypeVarDifferingUsage2? + @overload + def g1(self, x: List[int]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def g1(self, x: List[T]) -> T: ... + def g1(self, x): ... + + @overload + def g2(self, x: List[int]) -> List[int]: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def g2(self, x: List[T]) -> List[T]: ... + def g2(self, x): ... + +[builtins fixtures/list.pyi] + +[case testOverloadTypedDictDifferentRequiredKeysMeansDictsAreDisjoint] +from typing import overload +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int, 'y': int}) +B = TypedDict('B', {'x': int, 'y': str}) + +@overload +def f(x: A) -> int: ... +@overload +def f(x: B) -> str: ... +def f(x): pass +[builtins fixtures/dict.pyi] + +[case testOverloadedTypedDictPartiallyOverlappingRequiredKeys] +from typing import overload, Union +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int, 'y': Union[int, str]}) +B = TypedDict('B', {'x': int, 'y': Union[str, float]}) + +@overload +def f(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: B) -> str: ... +def f(x): pass + +@overload +def g(x: A) -> int: ... +@overload +def g(x: B) -> object: ... +def g(x): pass +[builtins fixtures/dict.pyi] + +[case testOverloadedTypedDictFullyNonTotalDictsAreAlwaysPartiallyOverlapping] +from typing import overload +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int, 'y': str}, total=False) +B = TypedDict('B', {'a': bool}, total=False) +C = TypedDict('C', {'x': str, 'y': int}, total=False) + +@overload +def f(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: B) -> str: ... +def f(x): pass + +@overload +def g(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def g(x: C) -> str: ... +def g(x): pass +[builtins fixtures/dict.pyi] + +[case testOverloadedTotalAndNonTotalTypedDictsCanPartiallyOverlap] +from typing import overload, Union +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int, 'y': str}) +B = TypedDict('B', {'x': Union[int, str], 'y': str, 'z': int}, total=False) + +@overload +def f1(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f1(x: B) -> str: ... +def f1(x): pass + +@overload +def f2(x: B) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f2(x: A) -> str: ... +def f2(x): pass + +[builtins fixtures/dict.pyi] + +[case testOverloadedTypedDictsWithSomeOptionalKeysArePartiallyOverlapping] +from typing import overload, Union +from mypy_extensions import TypedDict + +class A(TypedDict): + x: int + y: int + +class B(TypedDict, total=False): + z: str + +class C(TypedDict, total=False): + z: int + +@overload +def f(x: B) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: C) -> str: ... +def f(x): pass + +[builtins fixtures/dict.pyi] + +[case testOverloadedPartiallyOverlappingInheritedTypes1] +from typing import overload, List, Union, TypeVar, Generic + +class A: pass +class B: pass +class C: pass + +T = TypeVar('T') + +class ListSubclass(List[T]): pass +class Unrelated(Generic[T]): pass + +@overload +def f(x: List[Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: ListSubclass[Union[B, C]]) -> str: ... +def f(x): pass + +@overload +def g(x: List[Union[A, B]]) -> int: ... +@overload +def g(x: Unrelated[Union[B, C]]) -> str: ... +def g(x): pass + +[builtins fixtures/list.pyi] + +[case testOverloadedPartiallyOverlappingInheritedTypes2] +from typing import overload, List, Union + +class A: pass +class B: pass +class C: pass + +class ListSubclass(List[Union[B, C]]): pass + +@overload +def f(x: List[Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: ListSubclass) -> str: ... +def f(x): pass + +[builtins fixtures/list.pyi] + +[case testOverloadedPartiallyOverlappingInheritedTypes3] +from typing import overload, Union, Dict, TypeVar + +class A: pass +class B: pass +class C: pass + +S = TypeVar('S') + +class DictSubclass(Dict[str, S]): pass + +@overload +def f(x: Dict[str, Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: DictSubclass[Union[B, C]]) -> str: ... +def f(x): pass + +[builtins fixtures/dict.pyi] + +[case testOverloadedPartiallyOverlappingTypeVarsAndUnion] +from typing import overload, TypeVar, Union + +class A: pass +class B: pass +class C: pass + +S = TypeVar('S', A, B) + +@overload +def f(x: S) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: Union[B, C]) -> str: ... +def f(x): pass + +@overload +def g(x: Union[B, C]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def g(x: S) -> str: ... +def g(x): pass + +[case testOverloadPartiallyOverlappingTypeVarsIdentical] +from typing import overload, TypeVar, Union + +T = TypeVar('T') + +class A: pass +class B: pass +class C: pass + +@overload +def f(x: T, y: T, z: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: T, y: T, z: Union[B, C]) -> str: ... +def f(x, y, z): pass + +[case testOverloadedPartiallyOverlappingCallables] +from typing import overload, Union, Callable + +class A: pass +class B: pass +class C: pass + +@overload +def f(x: Callable[[Union[A, B]], int]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: Callable[[Union[B, C]], int]) -> str: ... +def f(x): pass + [case testOverloadNotConfusedForProperty] from typing import overload @@ -3444,7 +3839,7 @@ T = TypeVar('T') class FakeAttribute(Generic[T]): @overload - def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types + def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload def dummy(self, instance: T, owner: Type[T]) -> int: ... def dummy(self, instance: Optional[T], owner: Type[T]) -> Union['FakeAttribute[T]', int]: ... @@ -4186,3 +4581,54 @@ def g(x: str) -> int: ... [builtins fixtures/list.pyi] [typing fixtures/typing-full.pyi] [out] + +[case testOverloadsIgnorePromotions] +from typing import overload, List, Union, _promote + +class Parent: pass +class Child(Parent): pass + +children: List[Child] +parents: List[Parent] + +@overload +def f(x: Child) -> List[Child]: pass # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def f(x: Parent) -> List[Parent]: pass +def f(x: Union[Child, Parent]) -> Union[List[Child], List[Parent]]: + if isinstance(x, Child): + reveal_type(x) # E: Revealed type is '__main__.Child' + return children + else: + reveal_type(x) # E: Revealed type is '__main__.Parent' + return parents + +ints: List[int] +floats: List[float] + +@overload +def g(x: int) -> List[int]: pass +@overload +def g(x: float) -> List[float]: pass +def g(x: Union[int, float]) -> Union[List[int], List[float]]: + if isinstance(x, int): + reveal_type(x) # E: Revealed type is 'builtins.int' + return ints + else: + reveal_type(x) # E: Revealed type is 'builtins.float' + return floats + +[builtins fixtures/isinstancelist.pyi] + +[case testOverloadsTypesAndUnions] +from typing import overload, Type, Union + +class A: pass +class B: pass + +@overload +def f(x: Type[A]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def f(x: Union[Type[A], Type[B]]) -> str: ... +def f(x: Union[Type[A], Type[B]]) -> Union[int, str]: + return 1 diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index df8bc6548f14..850ec9ba6f38 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -1578,7 +1578,7 @@ d = {'weight0': 65.5} reveal_type(d['weight0']) # E: Revealed type is 'builtins.float*' d['weight0'] = 65 reveal_type(d['weight0']) # E: Revealed type is 'builtins.float*' -d['weight0'] *= 'a' # E: Unsupported operand types for * ("float" and "str") # E: Incompatible types in assignment (expression has type "str", target has type "float") +d['weight0'] *= 'a' # E: Unsupported operand types for * ("float" and "str") d['weight0'] *= 0.5 reveal_type(d['weight0']) # E: Revealed type is 'builtins.float*' d['weight0'] *= object() # E: Unsupported operand types for * ("float" and "object") diff --git a/test-data/unit/check-unreachable-code.test b/test-data/unit/check-unreachable-code.test index b86154302f8c..988038264d54 100644 --- a/test-data/unit/check-unreachable-code.test +++ b/test-data/unit/check-unreachable-code.test @@ -564,3 +564,43 @@ if typing.TYPE_CHECKING: reveal_type(x) # E: Revealed type is '__main__.B' [builtins fixtures/isinstancelist.pyi] + +[case testUnreachableWhenSuperclassIsAny] +# flags: --strict-optional +from typing import Any + +# This can happen if we're importing a class from a missing module +Parent: Any +class Child(Parent): + def foo(self) -> int: + reveal_type(self) # E: Revealed type is '__main__.Child' + if self is None: + reveal_type(self) + return None + reveal_type(self) # E: Revealed type is '__main__.Child' + return 3 + + def bar(self) -> int: + self = super(Child, self).something() + reveal_type(self) # E: Revealed type is '__main__.Child' + if self is None: + reveal_type(self) + return None + reveal_type(self) # E: Revealed type is '__main__.Child' + return 3 +[builtins fixtures/isinstance.pyi] + +[case testUnreachableWhenSuperclassIsAnyNoStrictOptional] +# flags: --no-strict-optional +from typing import Any + +Parent: Any +class Child(Parent): + def foo(self) -> int: + reveal_type(self) # E: Revealed type is '__main__.Child' + if self is None: + reveal_type(self) # E: Revealed type is '__main__.Child' + return None + reveal_type(self) # E: Revealed type is '__main__.Child' + return 3 +[builtins fixtures/isinstance.pyi] diff --git a/test-data/unit/fixtures/isinstance.pyi b/test-data/unit/fixtures/isinstance.pyi index ded946ce73fe..35535b9a588f 100644 --- a/test-data/unit/fixtures/isinstance.pyi +++ b/test-data/unit/fixtures/isinstance.pyi @@ -1,4 +1,4 @@ -from typing import Tuple, TypeVar, Generic, Union +from typing import Tuple, TypeVar, Generic, Union, cast, Any T = TypeVar('T') @@ -22,3 +22,5 @@ class bool(int): pass class str: def __add__(self, other: 'str') -> 'str': pass class ellipsis: pass + +NotImplemented = cast(Any, None) diff --git a/test-data/unit/fixtures/isinstancelist.pyi b/test-data/unit/fixtures/isinstancelist.pyi index 99aca1befe39..1831411319ef 100644 --- a/test-data/unit/fixtures/isinstancelist.pyi +++ b/test-data/unit/fixtures/isinstancelist.pyi @@ -14,6 +14,7 @@ def issubclass(x: object, t: Union[type, Tuple]) -> bool: pass class int: def __add__(self, x: int) -> int: pass +class float: pass class bool(int): pass class str: def __add__(self, x: str) -> str: pass diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 606b2bd47e01..446cb0f697fd 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -428,10 +428,10 @@ b'' < '' '' < bytearray() bytearray() < '' [out] -_program.py:2: error: Unsupported operand types for > ("bytes" and "str") -_program.py:3: error: Unsupported operand types for > ("str" and "bytes") -_program.py:4: error: Unsupported operand types for > ("bytearray" and "str") -_program.py:5: error: Unsupported operand types for > ("str" and "bytearray") +_program.py:2: error: Unsupported operand types for < ("str" and "bytes") +_program.py:3: error: Unsupported operand types for < ("bytes" and "str") +_program.py:4: error: Unsupported operand types for < ("str" and "bytearray") +_program.py:5: error: Unsupported operand types for < ("bytearray" and "str") [case testInplaceOperatorMethod] import typing