From 58d774392f460b90e68277cd2a900cf78a93c790 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 24 Jan 2023 15:53:25 +0000 Subject: [PATCH 1/2] Fix strict equality with enum type with custom __eq__ Fixes regression introduced in #14513. --- mypy/checkexpr.py | 21 +++++++++++++-------- test-data/unit/check-expressions.test | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8dea7d0e8551..a60d60e3cd82 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2988,21 +2988,19 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # testCustomEqCheckStrictEquality for an example. if not w.has_new_errors() and operator in ("==", "!="): right_type = self.accept(right) - # Also flag non-overlapping literals in situations like: - # x: Literal['a', 'b'] - # if x == 'c': - # ... - left_type = try_getting_literal(left_type) - right_type = try_getting_literal(right_type) if self.dangerous_comparison(left_type, right_type): + # Show the most specific literal types possible + left_type = try_getting_literal(left_type) + right_type = try_getting_literal(right_type) self.msg.dangerous_comparison(left_type, right_type, "equality", e) elif operator == "is" or operator == "is not": right_type = self.accept(right) # validate the right operand sub_result = self.bool_type() - left_type = try_getting_literal(left_type) - right_type = try_getting_literal(right_type) if self.dangerous_comparison(left_type, right_type): + # Show the most specific literal types possible + left_type = try_getting_literal(left_type) + right_type = try_getting_literal(right_type) self.msg.dangerous_comparison(left_type, right_type, "identity", e) method_type = None else: @@ -3064,6 +3062,13 @@ def dangerous_comparison( if custom_special_method(left, "__eq__") or custom_special_method(right, "__eq__"): return False + # Also flag non-overlapping literals in situations like: + # x: Literal['a', 'b'] + # if x == 'c': + # ... + left = try_getting_literal(left) + right = try_getting_literal(right) + if self.chk.binder.is_unreachable_warning_suppressed(): # We are inside a function that contains type variables with value restrictions in # its signature. In this case we just suppress all strict-equality checks to avoid diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 20ccbb17d5d5..a4ce93f07e82 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2221,6 +2221,25 @@ int == y y == int [builtins fixtures/bool.pyi] +[case testStrictEqualityAndEnumWithCustomEq] +# flags: --strict-equality +from enum import Enum + +class E1(Enum): + X = 0 + Y = 1 + +class E2(Enum): + X = 0 + Y = 1 + + def __eq__(self, other: object) -> bool: + return bool() + +E1.X == E1.Y # E: Non-overlapping equality check (left operand type: "Literal[E1.X]", right operand type: "Literal[E1.Y]") +E2.X == E2.Y +[builtins fixtures/bool.pyi] + [case testUnimportedHintAny] def f(x: Any) -> None: # E: Name "Any" is not defined \ # N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any") From 7b334480a995dc058cc7f1c90a6db6e261a2b223 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 24 Jan 2023 16:39:11 +0000 Subject: [PATCH 2/2] Fix bytes contains --- mypy/checkexpr.py | 22 ++++++++++++++-------- test-data/unit/check-expressions.test | 7 +++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a60d60e3cd82..e19d48f4f5e7 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2970,7 +2970,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: not local_errors.has_new_errors() and cont_type and self.dangerous_comparison( - left_type, cont_type, original_container=right_type + left_type, cont_type, original_container=right_type, prefer_literal=False ) ): self.msg.dangerous_comparison(left_type, cont_type, "container", e) @@ -3034,7 +3034,12 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Type | None: return None def dangerous_comparison( - self, left: Type, right: Type, original_container: Type | None = None + self, + left: Type, + right: Type, + original_container: Type | None = None, + *, + prefer_literal: bool = True, ) -> bool: """Check for dangerous non-overlapping comparisons like 42 == 'no'. @@ -3062,12 +3067,13 @@ def dangerous_comparison( if custom_special_method(left, "__eq__") or custom_special_method(right, "__eq__"): return False - # Also flag non-overlapping literals in situations like: - # x: Literal['a', 'b'] - # if x == 'c': - # ... - left = try_getting_literal(left) - right = try_getting_literal(right) + if prefer_literal: + # Also flag non-overlapping literals in situations like: + # x: Literal['a', 'b'] + # if x == 'c': + # ... + left = try_getting_literal(left) + right = try_getting_literal(right) if self.chk.binder.is_unreachable_warning_suppressed(): # We are inside a function that contains type variables with value restrictions in diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index a4ce93f07e82..49a3f0d4aaa7 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2240,6 +2240,13 @@ E1.X == E1.Y # E: Non-overlapping equality check (left operand type: "Literal[E E2.X == E2.Y [builtins fixtures/bool.pyi] +[case testStrictEqualityWithBytesContains] +# flags: --strict-equality +data = b"xy" +b"x" in data +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + [case testUnimportedHintAny] def f(x: Any) -> None: # E: Name "Any" is not defined \ # N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")