Skip to content

[mypyc] Add bytes primitive type #9611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive,
is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive,
int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive,
is_int64_rprimitive, is_bit_rprimitive
is_int64_rprimitive, is_bit_rprimitive, is_bytes_rprimitive
)
from mypyc.ir.func_ir import FuncDecl
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
Expand Down Expand Up @@ -399,7 +399,7 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
# TODO: Verify refcount handling.
if (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ)
or is_float_rprimitive(typ) or is_str_rprimitive(typ) or is_int_rprimitive(typ)
or is_bool_rprimitive(typ)):
or is_bool_rprimitive(typ) or is_bytes_rprimitive(typ)):
if declare_dest:
self.emit_line('PyObject *{};'.format(dest))
if is_list_rprimitive(typ):
Expand All @@ -416,6 +416,8 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
prefix = 'PyLong'
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
prefix = 'PyBool'
elif is_bytes_rprimitive(typ):
prefix = 'PyBytes'
else:
assert False, 'unexpected primitive type'
check = '({}_Check({}))'
Expand Down
8 changes: 6 additions & 2 deletions mypyc/doc/dev-intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,12 @@ operations, and so on. You likely also want to add some faster,
specialized primitive operations for the type (see Adding a
Specialized Primitive Operation above for how to do this).

Add a test case to `mypyc/test-data/run.test` to test compilation and
running compiled code. Ideas for things to test:
Add a test case file `mypyc/test-data/run-<type>.test` to test
compilation and running compiled code. Update `mypyc/test/test_run.py`
to include the new file. You may also need to extend
`mypyc/test-data/fixtures/ir.py` (the stubs used by tests).

Ideas for things to test:

* Test using the type as an argument.

Expand Down
8 changes: 8 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ def __repr__(self) -> str:
# (PyUnicode).
str_rprimitive = RPrimitive('builtins.str', is_unboxed=False, is_refcounted=True) # type: Final

# Python bytes object.
bytes_rprimitive = RPrimitive('builtins.bytes', is_unboxed=False,
is_refcounted=True) # type: Final

# Tuple of an arbitrary length (corresponds to Tuple[t, ...], with
# explicit '...').
tuple_rprimitive = RPrimitive('builtins.tuple', is_unboxed=False,
Expand Down Expand Up @@ -364,6 +368,10 @@ def is_str_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.str'


def is_bytes_rprimitive(rtype: RType) -> bool:
    """Tell whether *rtype* is the primitive type named 'builtins.bytes'."""
    if not isinstance(rtype, RPrimitive):
        return False
    return rtype.name == 'builtins.bytes'


def is_tuple_rprimitive(rtype: RType) -> bool:
    """Tell whether *rtype* is the primitive type named 'builtins.tuple'."""
    if isinstance(rtype, RPrimitive):
        return rtype.name == 'builtins.tuple'
    return False

Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
c_pyssize_t_rprimitive, is_short_int_rprimitive, is_tagged, PyVarObject, short_int_rprimitive,
is_list_rprimitive, is_tuple_rprimitive, is_dict_rprimitive, is_set_rprimitive, PySetObject,
none_rprimitive, RTuple, is_bool_rprimitive, is_str_rprimitive, c_int_rprimitive,
pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive
pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive, bytes_rprimitive
)
from mypyc.ir.func_ir import FuncDecl, FuncSignature
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
Expand Down Expand Up @@ -467,7 +467,7 @@ def load_static_float(self, value: float) -> Value:
def load_static_bytes(self, value: bytes) -> Value:
"""Loads a static bytes value into a register."""
identifier = self.literal_static_name(value)
return self.add(LoadGlobal(object_rprimitive, identifier, ann=value))
return self.add(LoadGlobal(bytes_rprimitive, identifier, ann=value))

def load_static_complex(self, value: complex) -> Value:
"""Loads a static complex value into a register."""
Expand Down
4 changes: 3 additions & 1 deletion mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mypyc.ir.rtypes import (
RType, RUnion, RTuple, RInstance, object_rprimitive, dict_rprimitive, tuple_rprimitive,
none_rprimitive, int_rprimitive, float_rprimitive, str_rprimitive, bool_rprimitive,
list_rprimitive, set_rprimitive
list_rprimitive, set_rprimitive, bytes_rprimitive
)
from mypyc.ir.func_ir import FuncSignature, FuncDecl, RuntimeArg
from mypyc.ir.class_ir import ClassIR
Expand Down Expand Up @@ -57,6 +57,8 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType:
return bool_rprimitive
elif typ.type.fullname == 'builtins.list':
return list_rprimitive
elif typ.type.fullname == 'builtins.bytes':
return bytes_rprimitive
# Dict subclasses are at least somewhat common and we
# specifically support them, so make sure that dict operations
# get optimized on them.
Expand Down
11 changes: 11 additions & 0 deletions mypyc/primitives/bytes_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Primitive bytes ops."""

from mypyc.ir.rtypes import object_rprimitive
from mypyc.primitives.registry import load_address_op


# Get the 'bytes' type object.
#
# Registers an entry for the name 'builtins.bytes' whose source is the C
# symbol PyBytes_Type, typed as a generic object. The call runs at import
# time; mypyc.primitives.registry imports this module for that side effect.
# NOTE(review): presumably this lets compiled code load the address of the
# type object directly instead of looking 'bytes' up by name -- confirm
# against load_address_op in registry.py.
load_address_op(
name='builtins.bytes',
type=object_rprimitive,
src='PyBytes_Type')
1 change: 1 addition & 0 deletions mypyc/primitives/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def load_address_op(name: str,
# Import various modules that set up global state.
import mypyc.primitives.int_ops # noqa
import mypyc.primitives.str_ops # noqa
import mypyc.primitives.bytes_ops # noqa
import mypyc.primitives.list_ops # noqa
import mypyc.primitives.dict_ops # noqa
import mypyc.primitives.tuple_ops # noqa
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

def extract_line(tb):
formatted = '\n'.join(format_tb(tb))
m = re.search('File "native.py", line ([0-9]+), in test_', formatted)
m = re.search('File "(native|driver).py", line ([0-9]+), in (test_|<module>)', formatted)
return m.group(1)

# Sort failures by line number of test function.
Expand Down
6 changes: 3 additions & 3 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def __mul__(self, n: complex) -> complex: pass
def __truediv__(self, n: complex) -> complex: pass

class bytes:
def __init__(self, x: object) -> None: pass
def __add__(self, x: object) -> bytes: pass
def __eq__(self, x:object) -> bool:pass
def __add__(self, x: bytes) -> bytes: pass
def __eq__(self, x: object) -> bool:pass
def __ne__(self, x: object) -> bool: pass
def __getitem__(self, i: int) -> int: pass
def join(self, x: Iterable[object]) -> bytes: pass

class bool(int):
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def f() -> bytes:
return b'1234'
[out]
def f():
r0, x, r1 :: object
r0, x, r1 :: bytes
L0:
r0 = load_global CPyStatic_bytes_1 :: static (b'\xf0')
x = r0
Expand Down
55 changes: 55 additions & 0 deletions mypyc/test-data/run-bytes.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Bytes test cases (compile and run)

[case testBytesBasics]
# Note: Add tests for additional operations to testBytesOps or in a new test case

def f(x: bytes) -> bytes:
return x

def eq(a: bytes, b: bytes) -> bool:
return a == b

def neq(a: bytes, b: bytes) -> bool:
return a != b
[file driver.py]
from native import f, eq, neq
assert f(b'123') == b'123'
assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0'
assert eq(b'123', b'123')
assert not eq(b'123', b'1234')
assert neq(b'123', b'1234')
try:
f('x')
assert False
except TypeError:
pass

[case testBytesOps]
def test_indexing() -> None:
# Use bytes() to avoid constant folding
b = b'asdf' + bytes()
assert b[0] == 97
assert b[1] == 115
assert b[3] == 102
assert b[-1] == 102
b = b'\xfe\x15' + bytes()
assert b[0] == 254
assert b[1] == 21

def test_concat() -> None:
b1 = b'123' + bytes()
b2 = b'456' + bytes()
assert b1 + b2 == b'123456'

def test_join() -> None:
seq = (b'1', b'"', b'\xf0')
assert b'\x07'.join(seq) == b'1\x07"\x07\xf0'
assert b', '.join(()) == b''
assert b', '.join([bytes() + b'ab']) == b'ab'
assert b', '.join([bytes() + b'ab', b'cd']) == b'ab, cd'

def test_len() -> None:
# Use bytes() to avoid constant folding
b = b'foo' + bytes()
assert len(b) == 3
assert len(bytes()) == 0
26 changes: 0 additions & 26 deletions mypyc/test-data/run-primitives.test
Original file line number Diff line number Diff line change
Expand Up @@ -255,32 +255,6 @@ assert str(to_int(3.14)) == '3'
assert str(to_int(3)) == '3'
assert get_complex() == 3+5j

[case testBytes]
def f(x: bytes) -> bytes:
return x

def concat(a: bytes, b: bytes) -> bytes:
return a + b

def eq(a: bytes, b: bytes) -> bool:
return a == b

def neq(a: bytes, b: bytes) -> bool:
return a != b

def join() -> bytes:
seq = (b'1', b'"', b'\xf0')
return b'\x07'.join(seq)
[file driver.py]
from native import f, concat, eq, neq, join
assert f(b'123') == b'123'
assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0'
assert concat(b'123', b'456') == b'123456'
assert eq(b'123', b'123')
assert not eq(b'123', b'1234')
assert neq(b'123', b'1234')
assert join() == b'1\x07"\x07\xf0'

[case testDel]
from typing import List
from testutil import assertRaises
Expand Down
1 change: 1 addition & 0 deletions mypyc/test/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
'run-integers.test',
'run-bools.test',
'run-strings.test',
'run-bytes.test',
'run-tuples.test',
'run-lists.test',
'run-dicts.test',
Expand Down