Skip to content

Commit 4ff341f

Browse files
ilevkivskyimsullivan
authored andcommitted
Make --strict-equality stricter with literals (#7310)
1 parent 4952275 commit 4ff341f

File tree

4 files changed

+41
-10
lines changed

4 files changed

+41
-10
lines changed

mypy/checkexpr.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,17 +1988,25 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
19881988
# testCustomEqCheckStrictEquality for an example.
19891989
if self.msg.errors.total_errors() == err_count and operator in ('==', '!='):
19901990
right_type = self.accept(right)
1991+
# We suppress the error if there is a custom __eq__() method on either
1992+
# side. User defined (or even standard library) classes can define this
1993+
# to return True for comparisons between non-overlapping types.
19911994
if (not custom_equality_method(left_type) and
19921995
not custom_equality_method(right_type)):
1993-
# We suppress the error if there is a custom __eq__() method on either
1994-
# side. User defined (or even standard library) classes can define this
1995-
# to return True for comparisons between non-overlapping types.
1996+
# Also flag non-overlapping literals in situations like:
1997+
# x: Literal['a', 'b']
1998+
# if x == 'c':
1999+
# ...
2000+
left_type = try_getting_literal(left_type)
2001+
right_type = try_getting_literal(right_type)
19962002
if self.dangerous_comparison(left_type, right_type):
19972003
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)
19982004

19992005
elif operator == 'is' or operator == 'is not':
20002006
right_type = self.accept(right) # validate the right operand
20012007
sub_result = self.bool_type()
2008+
left_type = try_getting_literal(left_type)
2009+
right_type = try_getting_literal(right_type)
20022010
if self.dangerous_comparison(left_type, right_type):
20032011
self.msg.dangerous_comparison(left_type, right_type, 'identity', e)
20042012
method_type = None
@@ -4017,6 +4025,13 @@ def is_literal_type_like(t: Optional[Type]) -> bool:
40174025
return False
40184026

40194027

4028+
def try_getting_literal(typ: Type) -> Type:
4029+
"""If possible, get a more precise literal type for a given type."""
4030+
if isinstance(typ, Instance) and typ.last_known_value is not None:
4031+
return typ.last_known_value
4032+
return typ
4033+
4034+
40204035
def is_expr_literal_type(node: Expression) -> bool:
40214036
"""Returns 'true' if the given node is a Literal"""
40224037
valid = ('typing.Literal', 'typing_extensions.Literal')

test-data/unit/check-columns.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ if int():
226226

227227
[case testColumnNonOverlappingEqualityCheck]
228228
# flags: --strict-equality
229-
if 1 == '': # E:4: Non-overlapping equality check (left operand type: "int", right operand type: "str")
229+
if 1 == '': # E:4: Non-overlapping equality check (left operand type: "Literal[1]", right operand type: "Literal['']")
230230
pass
231231
[builtins fixtures/bool.pyi]
232232

test-data/unit/check-expressions.test

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,11 +2029,11 @@ class B: ...
20292029
a: Union[int, str]
20302030
b: Union[A, B]
20312031

2032-
a == 42
2033-
b == 42 # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "int")
2032+
a == int()
2033+
b == int() # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "int")
20342034

2035-
a is 42
2036-
b is 42 # E: Non-overlapping identity check (left operand type: "Union[A, B]", right operand type: "int")
2035+
a is int()
2036+
b is int() # E: Non-overlapping identity check (left operand type: "Union[A, B]", right operand type: "int")
20372037

20382038
ca: Union[Container[int], Container[str]]
20392039
cb: Union[Container[A], Container[B]]
@@ -2061,7 +2061,7 @@ x in b'abc'
20612061

20622062
[case testStrictEqualityNoPromotePy3]
20632063
# flags: --strict-equality
2064-
'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes")
2064+
'a' == b'a' # E: Non-overlapping equality check (left operand type: "Literal['a']", right operand type: "Literal[b'a']")
20652065
b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str")
20662066

20672067
x: str
@@ -2271,6 +2271,22 @@ def f(x: T) -> T:
22712271
return x
22722272
[builtins fixtures/bool.pyi]
22732273

2274+
[case testStrictEqualityWithALiteral]
2275+
# flags: --strict-equality
2276+
from typing_extensions import Literal, Final
2277+
2278+
def returns_a_or_b() -> Literal['a', 'b']:
2279+
...
2280+
def returns_1_or_2() -> Literal[1, 2]:
2281+
...
2282+
THREE: Final = 3
2283+
2284+
if returns_a_or_b() == 'c': # E: Non-overlapping equality check (left operand type: "Union[Literal['a'], Literal['b']]", right operand type: "Literal['c']")
2285+
...
2286+
if returns_1_or_2() is THREE: # E: Non-overlapping identity check (left operand type: "Union[Literal[1], Literal[2]]", right operand type: "Literal[3]")
2287+
...
2288+
[builtins fixtures/bool.pyi]
2289+
22742290
[case testUnimportedHintAny]
22752291
def f(x: Any) -> None: # E: Name 'Any' is not defined \
22762292
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")

test-data/unit/check-flags.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,7 +1125,7 @@ def f(c: A) -> None: # E: Missing type parameters for generic type "A"
11251125
[case testStrictEqualityPerFile]
11261126
# flags: --config-file tmp/mypy.ini
11271127
import b
1128-
42 == 'no' # E: Non-overlapping equality check (left operand type: "int", right operand type: "str")
1128+
42 == 'no' # E: Non-overlapping equality check (left operand type: "Literal[42]", right operand type: "Literal['no']")
11291129
[file b.py]
11301130
42 == 'no'
11311131
[file mypy.ini]

0 commit comments

Comments
 (0)