Skip to content

Fix strict equality with enum type with custom __eq__ #14518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2970,7 +2970,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
not local_errors.has_new_errors()
and cont_type
and self.dangerous_comparison(
left_type, cont_type, original_container=right_type
left_type, cont_type, original_container=right_type, prefer_literal=False
)
):
self.msg.dangerous_comparison(left_type, cont_type, "container", e)
Expand All @@ -2988,21 +2988,19 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
# testCustomEqCheckStrictEquality for an example.
if not w.has_new_errors() and operator in ("==", "!="):
right_type = self.accept(right)
# Also flag non-overlapping literals in situations like:
# x: Literal['a', 'b']
# if x == 'c':
# ...
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
# Show the most specific literal types possible
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
self.msg.dangerous_comparison(left_type, right_type, "equality", e)

elif operator == "is" or operator == "is not":
right_type = self.accept(right) # validate the right operand
sub_result = self.bool_type()
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
# Show the most specific literal types possible
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
self.msg.dangerous_comparison(left_type, right_type, "identity", e)
method_type = None
else:
Expand Down Expand Up @@ -3036,7 +3034,12 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Type | None:
return None

def dangerous_comparison(
self, left: Type, right: Type, original_container: Type | None = None
self,
left: Type,
right: Type,
original_container: Type | None = None,
*,
prefer_literal: bool = True,
) -> bool:
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.

Expand Down Expand Up @@ -3064,6 +3067,14 @@ def dangerous_comparison(
if custom_special_method(left, "__eq__") or custom_special_method(right, "__eq__"):
return False

if prefer_literal:
# Also flag non-overlapping literals in situations like:
# x: Literal['a', 'b']
# if x == 'c':
# ...
left = try_getting_literal(left)
right = try_getting_literal(right)

if self.chk.binder.is_unreachable_warning_suppressed():
# We are inside a function that contains type variables with value restrictions in
# its signature. In this case we just suppress all strict-equality checks to avoid
Expand Down
26 changes: 26 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2221,6 +2221,32 @@ int == y
y == int
[builtins fixtures/bool.pyi]

[case testStrictEqualityAndEnumWithCustomEq]
# flags: --strict-equality
from enum import Enum

class E1(Enum):
X = 0
Y = 1

class E2(Enum):
X = 0
Y = 1

def __eq__(self, other: object) -> bool:
return bool()

E1.X == E1.Y # E: Non-overlapping equality check (left operand type: "Literal[E1.X]", right operand type: "Literal[E1.Y]")
E2.X == E2.Y
[builtins fixtures/bool.pyi]

[case testStrictEqualityWithBytesContains]
# flags: --strict-equality
data = b"xy"
b"x" in data
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testUnimportedHintAny]
def f(x: Any) -> None: # E: Name "Any" is not defined \
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")
Expand Down