diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 702c6cb4f58d..599d67af4899 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -286,8 +286,8 @@ def check_call(self, callee: Type, args: List[Node], callee) elif isinstance(callee, Instance): call_function = analyze_member_access('__call__', callee, context, - False, False, self.named_type, self.not_ready_callback, - self.msg) + False, False, False, self.named_type, + self.not_ready_callback, self.msg) return self.check_call(call_function, args, arg_kinds, context, arg_names, callable_node, arg_messages) elif isinstance(callee, TypeVarType): @@ -861,7 +861,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr, else: # This is a reference to a non-module attribute. return analyze_member_access(e.name, self.accept(e.expr), e, - is_lvalue, False, + is_lvalue, False, False, self.named_type, self.not_ready_callback, self.msg) def analyze_external_member_access(self, member: str, base_type: Type, @@ -870,7 +870,7 @@ def analyze_external_member_access(self, member: str, base_type: Type, refer to private definitions. Return the result type. """ # TODO remove; no private definitions in mypy - return analyze_member_access(member, base_type, context, False, False, + return analyze_member_access(member, base_type, context, False, False, False, self.named_type, self.not_ready_callback, self.msg) def visit_int_expr(self, e: IntExpr) -> Type: @@ -1008,7 +1008,7 @@ def check_op_local(self, method: str, base_type: Type, arg: Node, Return tuple (result type, inferred operator method type). """ - method_type = analyze_member_access(method, base_type, context, False, False, + method_type = analyze_member_access(method, base_type, context, False, False, True, self.named_type, self.not_ready_callback, local_errors) return self.check_call(method_type, [arg], [nodes.ARG_POS], context, arg_messages=local_errors) @@ -1434,7 +1434,7 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type: if not self.chk.typing_mode_full(): return AnyType() return analyze_member_access(e.name, self_type(e.info), e, - is_lvalue, True, + is_lvalue, True, False, self.named_type, self.not_ready_callback, self.msg, base) else: diff --git a/mypy/checkmember.py b/mypy/checkmember.py index a4a53f66dae0..36768e0830ea 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -8,7 +8,8 @@ DeletedType, NoneTyp, TypeType ) from mypy.nodes import TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context -from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, function_type, Decorator, OverloadedFuncDef +from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, OpExpr, ComparisonExpr +from mypy.nodes import function_type, Decorator, OverloadedFuncDef from mypy.messages import MessageBuilder from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance @@ -23,6 +24,7 @@ def analyze_member_access(name: str, node: Context, is_lvalue: bool, is_super: bool, + is_operator: bool, builtin_type: Callable[[str], Instance], not_ready_callback: Callable[[str, Context], None], msg: MessageBuilder, @@ -79,20 +81,20 @@ def analyze_member_access(name: str, elif isinstance(typ, NoneTyp): # The only attribute NoneType has are those it inherits from object return analyze_member_access(name, builtin_type('builtins.object'), node, is_lvalue, - is_super, builtin_type, not_ready_callback, msg, + is_super, is_operator, builtin_type, not_ready_callback, msg, report_type=report_type) elif isinstance(typ, UnionType): # The base object has dynamic type. msg.disable_type_names += 1 - results = [analyze_member_access(name, subtype, node, is_lvalue, - is_super, builtin_type, not_ready_callback, msg) + results = [analyze_member_access(name, subtype, node, is_lvalue, is_super, + is_operator, builtin_type, not_ready_callback, msg) for subtype in typ.items] msg.disable_type_names -= 1 return UnionType.make_simplified_union(results) elif isinstance(typ, TupleType): # Actually look up from the fallback instance type. - return analyze_member_access(name, typ.fallback, node, is_lvalue, - is_super, builtin_type, not_ready_callback, msg) + return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super, + is_operator, builtin_type, not_ready_callback, msg) elif isinstance(typ, FunctionLike) and typ.is_type_obj(): # Class attribute. # TODO super? @@ -100,24 +102,38 @@ def analyze_member_access(name: str, if isinstance(ret_type, TupleType): ret_type = ret_type.fallback if isinstance(ret_type, Instance): - result = analyze_class_attribute_access(ret_type, name, node, is_lvalue, - builtin_type, not_ready_callback, msg) - if result: - return result + if not is_operator: + # When Python sees an operator (eg `3 == 4`), it automatically translates that + # into something like `int.__eq__(3, 4)` instead of `(3).__eq__(4)` as an + # optimation. + # + # While it normally it doesn't matter which of the two versions are used, it + # does cause inconsistencies when working with classes. For example, translating + # `int == int` to `int.__eq__(int)` would not work since `int.__eq__` is meant to + # compare two int _instances_. What we really want is `type(int).__eq__`, which + # is meant to compare two types or classes. + # + # This check makes sure that when we encounter an operator, we skip looking up + # the corresponding method in the current instance to avoid this edge case. + # See https://github.com/python/mypy/pull/1787 for more info. + result = analyze_class_attribute_access(ret_type, name, node, is_lvalue, + builtin_type, not_ready_callback, msg) + if result: + return result # Look up from the 'type' type. return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super, - builtin_type, not_ready_callback, msg, + is_operator, builtin_type, not_ready_callback, msg, report_type=report_type) else: assert False, 'Unexpected type {}'.format(repr(ret_type)) elif isinstance(typ, FunctionLike): # Look up from the 'function' type. return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super, - builtin_type, not_ready_callback, msg, + is_operator, builtin_type, not_ready_callback, msg, report_type=report_type) elif isinstance(typ, TypeVarType): return analyze_member_access(name, typ.upper_bound, node, is_lvalue, is_super, - builtin_type, not_ready_callback, msg, + is_operator, builtin_type, not_ready_callback, msg, report_type=report_type) elif isinstance(typ, DeletedType): msg.deleted_as_rvalue(typ, node) @@ -130,14 +146,15 @@ def analyze_member_access(name: str, elif isinstance(typ.item, TypeVarType): if isinstance(typ.item.upper_bound, Instance): item = typ.item.upper_bound - if item: + if item and not is_operator: + # See comment above for why operators are skipped result = analyze_class_attribute_access(item, name, node, is_lvalue, builtin_type, not_ready_callback, msg) if result: return result fallback = builtin_type('builtins.type') return analyze_member_access(name, fallback, node, is_lvalue, is_super, - builtin_type, not_ready_callback, msg, + is_operator, builtin_type, not_ready_callback, msg, report_type=report_type) return msg.has_no_attr(report_type, name, node) diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 513478259d54..5670f9cd346a 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -2006,3 +2006,44 @@ reveal_type(User) # E: Revealed type is 'builtins.type' [builtins fixtures/args.py] [out] +[case testTypeTypeComparisonWorks] +class User: pass + +User == User +User == type(User()) +type(User()) == User +type(User()) == type(User()) + +User != User +User != type(User()) +type(User()) != User +type(User()) != type(User()) + +int == int +int == type(3) +type(3) == int +type(3) == type(3) + +int != int +int != type(3) +type(3) != int +type(3) != type(3) + +User is User +User is type(User) +type(User) is User +type(User) is type(User) + +int is int +int is type(3) +type(3) is int +type(3) is type(3) + +int.__eq__(int) +int.__eq__(3, 4) +[builtins fixtures/args.py] +[out] +main:33: error: Too few arguments for "__eq__" of "int" +main:33: error: Unsupported operand types for == ("int" and "int") + + diff --git a/test-data/unit/fixtures/args.py b/test-data/unit/fixtures/args.py index dd7d9ea61c6f..b084fc6c68e5 100644 --- a/test-data/unit/fixtures/args.py +++ b/test-data/unit/fixtures/args.py @@ -8,17 +8,21 @@ class object: def __init__(self) -> None: pass + def __eq__(self, o: object) -> bool: pass + def __ne__(self, o: object) -> bool: pass class type: @overload def __init__(self, o: object) -> None: pass @overload def __init__(self, name: str, bases: Tuple[type, ...], dict: Dict[str, Any]) -> None: pass + def __call__(self, *args: Any, **kwargs: Any) -> Any: pass class tuple(Iterable[Tco], Generic[Tco]): pass class dict(Generic[T, S]): pass -class int: pass +class int: + def __eq__(self, o: object) -> bool: pass class str: pass class bool: pass class function: pass