Skip to content

Commit 8e960b3

Browse files
authored
Apply --strict-equality special-casing for bytes also to bytearray (#7473)
Fixes #7465
1 parent 6dce58f commit 8e960b3

File tree

4 files changed

+19
-6
lines changed

4 files changed

+19
-6
lines changed

mypy/checkexpr.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,8 +2103,9 @@ def dangerous_comparison(self, left: Type, right: Type,
21032103
right = remove_optional(right)
21042104
if (original_container and has_bytes_component(original_container) and
21052105
has_bytes_component(left)):
2106-
# We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc'
2107-
# return True (and we want to show the error only if the check can _never_ be True).
2106+
# We need to special case bytes and bytearray, because 97 in b'abc', b'a' in b'abc',
2107+
# b'a' in bytearray(b'abc') etc. all return True (and we want to show the error only
2108+
# if the check can _never_ be True).
21082109
return False
21092110
if isinstance(left, Instance) and isinstance(right, Instance):
21102111
# Special case some builtin implementations of AbstractSet.
@@ -4136,11 +4137,12 @@ def custom_equality_method(typ: Type) -> bool:
41364137

41374138

41384139
def has_bytes_component(typ: Type) -> bool:
4139-
"""Is this the builtin bytes type, or a union that contains it?"""
4140+
"""Is this one of builtin byte types, or a union that contains it?"""
41404141
typ = get_proper_type(typ)
41414142
if isinstance(typ, UnionType):
41424143
return any(has_bytes_component(t) for t in typ.items)
4143-
if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes':
4144+
if isinstance(typ, Instance) and typ.type.fullname() in {'builtins.bytes',
4145+
'builtins.bytearray'}:
41444146
return True
41454147
return False
41464148

test-data/unit/check-expressions.test

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,6 +2059,13 @@ x in b'abc'
20592059
[builtins fixtures/primitives.pyi]
20602060
[typing fixtures/typing-full.pyi]
20612061

2062+
[case testStrictEqualityByteArraySpecial]
2063+
# flags: --strict-equality
2064+
b'abc' in bytearray(b'abcde')
2065+
bytearray(b'abc') in b'abcde' # OK on Python 3
2066+
[builtins fixtures/primitives.pyi]
2067+
[typing fixtures/typing-full.pyi]
2068+
20622069
[case testStrictEqualityNoPromotePy3]
20632070
# flags: --strict-equality
20642071
'a' == b'a' # E: Non-overlapping equality check (left operand type: "Literal['a']", right operand type: "Literal[b'a']")

test-data/unit/check-type-promotion.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ f(1)
2323

2424
[case testPromoteBytearrayToByte]
2525
def f(x: bytes) -> None: pass
26-
f(bytearray())
26+
f(bytearray(b''))
2727
[builtins fixtures/primitives.pyi]
2828

2929
[case testNarrowingDownFromPromoteTargetType]

test-data/unit/fixtures/primitives.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ class bytes(Sequence[int]):
2929
def __iter__(self) -> Iterator[int]: pass
3030
def __contains__(self, other: object) -> bool: pass
3131
def __getitem__(self, item: int) -> int: pass
32-
class bytearray: pass
32+
class bytearray(Sequence[int]):
33+
def __init__(self, x: bytes) -> None: pass
34+
def __iter__(self) -> Iterator[int]: pass
35+
def __contains__(self, other: object) -> bool: pass
36+
def __getitem__(self, item: int) -> int: pass
3337
class tuple(Generic[T]): pass
3438
class function: pass
3539
class ellipsis: pass

0 commit comments

Comments
 (0)