Skip to content

Commit 10dc451

Browse files
[mypyc] Add bytes primitive type (#10881)
Replaces #9611.
1 parent 9b10175 commit 10dc451

File tree

13 files changed

+119
-45
lines changed

13 files changed

+119
-45
lines changed

mypy/typeshed/stdlib/typing.pyi

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import collections # Needed by aliases like DefaultDict, see mypy issue 2986
22
import sys
33
from abc import ABCMeta, abstractmethod
44
from types import BuiltinFunctionType, CodeType, FrameType, FunctionType, MethodType, ModuleType, TracebackType
5+
from typing_extensions import Literal as _Literal
56

67
if sys.version_info >= (3, 7):
78
from types import MethodDescriptorType, MethodWrapperType, WrapperDescriptorType
@@ -563,19 +564,35 @@ class Match(Generic[AnyStr]):
563564
# this match instance.
564565
re: Pattern[AnyStr]
565566
def expand(self, template: AnyStr) -> AnyStr: ...
566-
# TODO: The return for a group may be None, except if __group is 0 or not given.
567+
# group() returns "AnyStr" or "AnyStr | None", depending on the pattern.
567568
@overload
568-
def group(self, __group: Union[str, int] = ...) -> AnyStr: ...
569+
def group(self, __group: _Literal[0] = ...) -> AnyStr: ...
569570
@overload
570-
def group(self, __group1: Union[str, int], __group2: Union[str, int], *groups: Union[str, int]) -> Tuple[AnyStr, ...]: ...
571-
def groups(self, default: AnyStr = ...) -> Sequence[AnyStr]: ...
572-
def groupdict(self, default: AnyStr = ...) -> dict[str, AnyStr]: ...
571+
def group(self, __group: str | int) -> AnyStr | Any: ...
572+
@overload
573+
def group(self, __group1: str | int, __group2: str | int, *groups: str | int) -> Tuple[AnyStr | Any, ...]: ...
574+
# Each item of groups()'s return tuple is either "AnyStr" or
575+
# "AnyStr | None", depending on the pattern.
576+
@overload
577+
def groups(self) -> Tuple[AnyStr | Any, ...]: ...
578+
@overload
579+
def groups(self, default: _T) -> Tuple[AnyStr | _T, ...]: ...
580+
# Each value in groupdict()'s return dict is either "AnyStr" or
581+
# "AnyStr | None", depending on the pattern.
582+
@overload
583+
def groupdict(self) -> dict[str, AnyStr | Any]: ...
584+
@overload
585+
def groupdict(self, default: _T) -> dict[str, AnyStr | _T]: ...
573586
def start(self, __group: Union[int, str] = ...) -> int: ...
574587
def end(self, __group: Union[int, str] = ...) -> int: ...
575588
def span(self, __group: Union[int, str] = ...) -> Tuple[int, int]: ...
576589
@property
577590
def regs(self) -> Tuple[Tuple[int, int], ...]: ... # undocumented
578-
def __getitem__(self, g: Union[int, str]) -> AnyStr: ...
591+
# __getitem__() returns "AnyStr" or "AnyStr | None", depending on the pattern.
592+
@overload
593+
def __getitem__(self, __key: _Literal[0]) -> AnyStr: ...
594+
@overload
595+
def __getitem__(self, __key: int | str) -> AnyStr | Any: ...
579596
if sys.version_info >= (3, 9):
580597
def __class_getitem__(cls, item: Any) -> GenericAlias: ...
581598

mypyc/codegen/emit.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive,
1616
is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive,
1717
int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive,
18-
is_int64_rprimitive, is_bit_rprimitive, is_range_rprimitive
18+
is_int64_rprimitive, is_bit_rprimitive, is_range_rprimitive, is_bytes_rprimitive
1919
)
2020
from mypyc.ir.func_ir import FuncDecl
2121
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
@@ -451,8 +451,8 @@ def emit_cast(self,
451451

452452
# TODO: Verify refcount handling.
453453
if (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ)
454-
or is_str_rprimitive(typ) or is_range_rprimitive(typ) or is_float_rprimitive(typ)
455-
or is_int_rprimitive(typ) or is_bool_rprimitive(typ)):
454+
or is_str_rprimitive(typ) or is_bytes_rprimitive(typ) or is_range_rprimitive(typ)
455+
or is_float_rprimitive(typ) or is_int_rprimitive(typ) or is_bool_rprimitive(typ)):
456456
if declare_dest:
457457
self.emit_line('PyObject *{};'.format(dest))
458458
if is_list_rprimitive(typ):
@@ -463,6 +463,8 @@ def emit_cast(self,
463463
prefix = 'PySet'
464464
elif is_str_rprimitive(typ):
465465
prefix = 'PyUnicode'
466+
elif is_bytes_rprimitive(typ):
467+
prefix = 'PyBytes'
466468
elif is_range_rprimitive(typ):
467469
prefix = 'PyRange'
468470
elif is_float_rprimitive(typ):

mypyc/ir/rtypes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ def __hash__(self) -> int:
328328
# (PyUnicode).
329329
str_rprimitive: Final = RPrimitive("builtins.str", is_unboxed=False, is_refcounted=True)
330330

331+
# Python bytes object.
332+
bytes_rprimitive: Final = RPrimitive('builtins.bytes', is_unboxed=False, is_refcounted=True)
333+
331334
# Tuple of an arbitrary length (corresponds to Tuple[t, ...], with
332335
# explicit '...').
333336
tuple_rprimitive: Final = RPrimitive("builtins.tuple", is_unboxed=False, is_refcounted=True)
@@ -410,6 +413,10 @@ def is_str_rprimitive(rtype: RType) -> bool:
410413
return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.str'
411414

412415

416+
def is_bytes_rprimitive(rtype: RType) -> bool:
    """Return True if rtype is an RPrimitive named 'builtins.bytes'."""
    if not isinstance(rtype, RPrimitive):
        return False
    return rtype.name == 'builtins.bytes'
418+
419+
413420
def is_tuple_rprimitive(rtype: RType) -> bool:
    """Return True if rtype is an RPrimitive named 'builtins.tuple'."""
    if not isinstance(rtype, RPrimitive):
        return False
    return rtype.name == 'builtins.tuple'
415422

mypyc/irbuild/ll_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
c_pyssize_t_rprimitive, is_short_int_rprimitive, is_tagged, PyVarObject, short_int_rprimitive,
3333
is_list_rprimitive, is_tuple_rprimitive, is_dict_rprimitive, is_set_rprimitive, PySetObject,
3434
none_rprimitive, RTuple, is_bool_rprimitive, is_str_rprimitive, c_int_rprimitive,
35-
pointer_rprimitive, PyObject, bit_rprimitive, is_bit_rprimitive,
36-
object_pointer_rprimitive, c_size_t_rprimitive, dict_rprimitive, PyListObject
35+
pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive,
36+
object_pointer_rprimitive, c_size_t_rprimitive, dict_rprimitive, bytes_rprimitive
3737
)
3838
from mypyc.ir.func_ir import FuncDecl, FuncSignature
3939
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
@@ -802,7 +802,7 @@ def load_str(self, value: str) -> Value:
802802

803803
def load_bytes(self, value: bytes) -> Value:
    """Load a bytes literal value.

    The literal is added as a LoadLiteral op typed as bytes_rprimitive
    (rather than the generic object_rprimitive), and the result of
    self.add() is returned.
    """
    return self.add(LoadLiteral(value, bytes_rprimitive))
806806

807807
def load_complex(self, value: complex) -> Value:
808808
"""Load a complex literal value."""

mypyc/irbuild/mapper.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mypyc.ir.rtypes import (
1313
RType, RUnion, RTuple, RInstance, object_rprimitive, dict_rprimitive, tuple_rprimitive,
1414
none_rprimitive, int_rprimitive, float_rprimitive, str_rprimitive, bool_rprimitive,
15-
list_rprimitive, set_rprimitive, range_rprimitive
15+
list_rprimitive, set_rprimitive, range_rprimitive, bytes_rprimitive
1616
)
1717
from mypyc.ir.func_ir import FuncSignature, FuncDecl, RuntimeArg
1818
from mypyc.ir.class_ir import ClassIR
@@ -43,10 +43,12 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType:
4343
return int_rprimitive
4444
elif typ.type.fullname == 'builtins.float':
4545
return float_rprimitive
46-
elif typ.type.fullname == 'builtins.str':
47-
return str_rprimitive
4846
elif typ.type.fullname == 'builtins.bool':
4947
return bool_rprimitive
48+
elif typ.type.fullname == 'builtins.str':
49+
return str_rprimitive
50+
elif typ.type.fullname == 'builtins.bytes':
51+
return bytes_rprimitive
5052
elif typ.type.fullname == 'builtins.list':
5153
return list_rprimitive
5254
# Dict subclasses are at least somewhat common and we

mypyc/primitives/bytes_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Primitive bytes ops."""
2+
3+
from mypyc.ir.rtypes import object_rprimitive
4+
from mypyc.primitives.registry import load_address_op
5+
6+
7+
# Get the 'bytes' type object.
8+
load_address_op(
9+
name='builtins.bytes',
10+
type=object_rprimitive,
11+
src='PyBytes_Type')

mypyc/primitives/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def load_address_op(name: str,
239239
# Import various modules that set up global state.
240240
import mypyc.primitives.int_ops # noqa
241241
import mypyc.primitives.str_ops # noqa
242+
import mypyc.primitives.bytes_ops # noqa
242243
import mypyc.primitives.list_ops # noqa
243244
import mypyc.primitives.dict_ops # noqa
244245
import mypyc.primitives.tuple_ops # noqa

mypyc/test-data/driver/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
def extract_line(tb):
2828
formatted = '\n'.join(format_tb(tb))
29-
m = re.search('File "native.py", line ([0-9]+), in test_', formatted)
29+
m = re.search('File "(native|driver).py", line ([0-9]+), in (test_|<module>)', formatted)
3030
if m is None:
3131
return "0"
3232
return m.group(1)

mypyc/test-data/fixtures/ir.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,14 @@ def __truediv__(self, n: complex) -> complex: pass
9696
def __neg__(self) -> complex: pass
9797

9898
# Minimal stand-in for builtins.bytes used by the IR test fixtures; the
# method bodies are intentionally empty because only the signatures matter.
class bytes:
    @overload
    def __init__(self) -> None: pass
    @overload
    def __init__(self, x: object) -> None: pass
    # Concatenation only accepts another bytes object.
    def __add__(self, x: bytes) -> bytes: pass
    def __eq__(self, x: object) -> bool: pass
    def __ne__(self, x: object) -> bool: pass
    # Indexing a bytes object yields an int.
    def __getitem__(self, i: int) -> int: pass
    def join(self, x: Iterable[object]) -> bytes: pass
104108

105109
class bool(int):

mypyc/test-data/irbuild-basic.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def f() -> bytes:
761761
return b'1234'
762762
[out]
763763
def f():
764-
r0, x, r1 :: object
764+
r0, x, r1 :: bytes
765765
L0:
766766
r0 = b'\xf0'
767767
x = r0

mypyc/test-data/run-bytes.test

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Bytes test cases (compile and run)
2+
3+
[case testBytesBasics]
4+
# Note: Add tests for additional operations to testBytesOps or in a new test case
5+
6+
def f(x: bytes) -> bytes:
7+
return x
8+
9+
def eq(a: bytes, b: bytes) -> bool:
10+
return a == b
11+
12+
def neq(a: bytes, b: bytes) -> bool:
13+
return a != b
14+
[file driver.py]
15+
from native import f, eq, neq
16+
assert f(b'123') == b'123'
17+
assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0'
18+
assert eq(b'123', b'123')
19+
assert not eq(b'123', b'1234')
20+
assert neq(b'123', b'1234')
21+
try:
22+
f('x')
23+
assert False
24+
except TypeError:
25+
pass
26+
27+
[case testBytesOps]
28+
def test_indexing() -> None:
29+
# Use bytes() to avoid constant folding
30+
b = b'asdf' + bytes()
31+
assert b[0] == 97
32+
assert b[1] == 115
33+
assert b[3] == 102
34+
assert b[-1] == 102
35+
b = b'\xfe\x15' + bytes()
36+
assert b[0] == 254
37+
assert b[1] == 21
38+
39+
def test_concat() -> None:
40+
b1 = b'123' + bytes()
41+
b2 = b'456' + bytes()
42+
assert b1 + b2 == b'123456'
43+
44+
def test_join() -> None:
45+
seq = (b'1', b'"', b'\xf0')
46+
assert b'\x07'.join(seq) == b'1\x07"\x07\xf0'
47+
assert b', '.join(()) == b''
48+
assert b', '.join([bytes() + b'ab']) == b'ab'
49+
assert b', '.join([bytes() + b'ab', b'cd']) == b'ab, cd'
50+
51+
def test_len() -> None:
52+
# Use bytes() to avoid constant folding
53+
b = b'foo' + bytes()
54+
assert len(b) == 3
55+
assert len(bytes()) == 0

mypyc/test-data/run-primitives.test

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -255,32 +255,6 @@ assert str(to_int(3.14)) == '3'
255255
assert str(to_int(3)) == '3'
256256
assert get_complex() == 3.5 + 6.2j
257257

258-
[case testBytes]
259-
def f(x: bytes) -> bytes:
260-
return x
261-
262-
def concat(a: bytes, b: bytes) -> bytes:
263-
return a + b
264-
265-
def eq(a: bytes, b: bytes) -> bool:
266-
return a == b
267-
268-
def neq(a: bytes, b: bytes) -> bool:
269-
return a != b
270-
271-
def join() -> bytes:
272-
seq = (b'1', b'"', b'\xf0')
273-
return b'\x07'.join(seq)
274-
[file driver.py]
275-
from native import f, concat, eq, neq, join
276-
assert f(b'123') == b'123'
277-
assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0'
278-
assert concat(b'123', b'456') == b'123456'
279-
assert eq(b'123', b'123')
280-
assert not eq(b'123', b'1234')
281-
assert neq(b'123', b'1234')
282-
assert join() == b'1\x07"\x07\xf0'
283-
284258
[case testDel]
285259
from typing import List
286260
from testutil import assertRaises

mypyc/test/test_run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
'run-floats.test',
3737
'run-bools.test',
3838
'run-strings.test',
39+
'run-bytes.test',
3940
'run-tuples.test',
4041
'run-lists.test',
4142
'run-dicts.test',

0 commit comments

Comments
 (0)