diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index 3f858c773b6f..5af8e0ef6f31 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -14,7 +14,7 @@ is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive, is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive, int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive, - is_int64_rprimitive, is_bit_rprimitive + is_int64_rprimitive, is_bit_rprimitive, is_bytes_rprimitive ) from mypyc.ir.func_ir import FuncDecl from mypyc.ir.class_ir import ClassIR, all_concrete_classes @@ -399,7 +399,7 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False, # TODO: Verify refcount handling. if (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ) or is_float_rprimitive(typ) or is_str_rprimitive(typ) or is_int_rprimitive(typ) - or is_bool_rprimitive(typ)): + or is_bool_rprimitive(typ) or is_bytes_rprimitive(typ)): if declare_dest: self.emit_line('PyObject *{};'.format(dest)) if is_list_rprimitive(typ): @@ -416,6 +416,8 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False, prefix = 'PyLong' elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ): prefix = 'PyBool' + elif is_bytes_rprimitive(typ): + prefix = 'PyBytes' else: assert False, 'unexpected primitive type' check = '({}_Check({}))' diff --git a/mypyc/doc/dev-intro.md b/mypyc/doc/dev-intro.md index 1e14d00645db..5b4da5776633 100644 --- a/mypyc/doc/dev-intro.md +++ b/mypyc/doc/dev-intro.md @@ -487,8 +487,12 @@ operations, and so on. You likely also want to add some faster, specialized primitive operations for the type (see Adding a Specialized Primitive Operation above for how to do this). -Add a test case to `mypyc/test-data/run.test` to test compilation and -running compiled code. Ideas for things to test: +Add a test case file `mypyc/test-data/run-.test` to test +compilation and running compiled code. Update `mypyc/test/test_run.py` +to include the new file. You may need to also add something to +`mypyc/test-data/fixtures/ir.py` (stubs used by tests). + +Ideas for things to test: * Test using the type as an argument. diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 3e6ec79d131f..ba30ce1a6585 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -294,6 +294,10 @@ def __repr__(self) -> str: # (PyUnicode). str_rprimitive = RPrimitive('builtins.str', is_unboxed=False, is_refcounted=True) # type: Final +# Python bytes object. +bytes_rprimitive = RPrimitive('builtins.bytes', is_unboxed=False, + is_refcounted=True) # type: Final + # Tuple of an arbitrary length (corresponds to Tuple[t, ...], with # explicit '...'). tuple_rprimitive = RPrimitive('builtins.tuple', is_unboxed=False, @@ -364,6 +368,10 @@ def is_str_rprimitive(rtype: RType) -> bool: return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.str' +def is_bytes_rprimitive(rtype: RType) -> bool: + return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.bytes' + + def is_tuple_rprimitive(rtype: RType) -> bool: return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.tuple' diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 93c70e46038c..1697b40b1acb 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -30,7 +30,7 @@ c_pyssize_t_rprimitive, is_short_int_rprimitive, is_tagged, PyVarObject, short_int_rprimitive, is_list_rprimitive, is_tuple_rprimitive, is_dict_rprimitive, is_set_rprimitive, PySetObject, none_rprimitive, RTuple, is_bool_rprimitive, is_str_rprimitive, c_int_rprimitive, - pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive + pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive, bytes_rprimitive ) from mypyc.ir.func_ir import FuncDecl, FuncSignature from mypyc.ir.class_ir import ClassIR, all_concrete_classes @@ -467,7 +467,7 @@ def load_static_float(self, value: float) -> Value: def load_static_bytes(self, value: bytes) -> Value: """Loads a static bytes value into a register.""" identifier = self.literal_static_name(value) - return self.add(LoadGlobal(object_rprimitive, identifier, ann=value)) + return self.add(LoadGlobal(bytes_rprimitive, identifier, ann=value)) def load_static_complex(self, value: complex) -> Value: """Loads a static complex value into a register.""" diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 364e650aa5dc..7fe48fd236d1 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -14,7 +14,7 @@ from mypyc.ir.rtypes import ( RType, RUnion, RTuple, RInstance, object_rprimitive, dict_rprimitive, tuple_rprimitive, none_rprimitive, int_rprimitive, float_rprimitive, str_rprimitive, bool_rprimitive, - list_rprimitive, set_rprimitive + list_rprimitive, set_rprimitive, bytes_rprimitive ) from mypyc.ir.func_ir import FuncSignature, FuncDecl, RuntimeArg from mypyc.ir.class_ir import ClassIR @@ -57,6 +57,8 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType: return bool_rprimitive elif typ.type.fullname == 'builtins.list': return list_rprimitive + elif typ.type.fullname == 'builtins.bytes': + return bytes_rprimitive # Dict subclasses are at least somewhat common and we # specifically support them, so make sure that dict operations # get optimized on them. diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py new file mode 100644 index 000000000000..5981284c9484 --- /dev/null +++ b/mypyc/primitives/bytes_ops.py @@ -0,0 +1,11 @@ +"""Primitive bytes ops.""" + +from mypyc.ir.rtypes import object_rprimitive +from mypyc.primitives.registry import load_address_op + + +# Get the 'bytes' type object. +load_address_op( + name='builtins.bytes', + type=object_rprimitive, + src='PyBytes_Type') diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index 454e7f1f6db4..25e807b563bb 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -341,6 +341,7 @@ def load_address_op(name: str, # Import various modules that set up global state. import mypyc.primitives.int_ops # noqa import mypyc.primitives.str_ops # noqa +import mypyc.primitives.bytes_ops # noqa import mypyc.primitives.list_ops # noqa import mypyc.primitives.dict_ops # noqa import mypyc.primitives.tuple_ops # noqa diff --git a/mypyc/test-data/driver/driver.py b/mypyc/test-data/driver/driver.py index 4db9843358f1..3268d8b7cd16 100644 --- a/mypyc/test-data/driver/driver.py +++ b/mypyc/test-data/driver/driver.py @@ -26,7 +26,7 @@ def extract_line(tb): formatted = '\n'.join(format_tb(tb)) - m = re.search('File "native.py", line ([0-9]+), in test_', formatted) + m = re.search('File "(native|driver).py", line ([0-9]+), in (test_|)', formatted) return m.group(1) # Sort failures by line number of test function. diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 4ffefb7432de..f637ce08c437 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -91,10 +91,10 @@ def __mul__(self, n: complex) -> complex: pass def __truediv__(self, n: complex) -> complex: pass class bytes: - def __init__(self, x: object) -> None: pass - def __add__(self, x: object) -> bytes: pass - def __eq__(self, x:object) -> bool:pass + def __add__(self, x: bytes) -> bytes: pass + def __eq__(self, x: object) -> bool:pass def __ne__(self, x: object) -> bool: pass + def __getitem__(self, i: int) -> int: pass def join(self, x: Iterable[object]) -> bytes: pass class bool(int): diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 0da337ce2d49..ab63eba6f78d 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -769,7 +769,7 @@ def f() -> bytes: return b'1234' [out] def f(): - r0, x, r1 :: object + r0, x, r1 :: bytes L0: r0 = load_global CPyStatic_bytes_1 :: static (b'\xf0') x = r0 diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test new file mode 100644 index 000000000000..51f16fe85519 --- /dev/null +++ b/mypyc/test-data/run-bytes.test @@ -0,0 +1,55 @@ +# Bytes test cases (compile and run) + +[case testBytesBasics] +# Note: Add tests for additional operations to testBytesOps or in a new test case + +def f(x: bytes) -> bytes: + return x + +def eq(a: bytes, b: bytes) -> bool: + return a == b + +def neq(a: bytes, b: bytes) -> bool: + return a != b +[file driver.py] +from native import f, eq, neq +assert f(b'123') == b'123' +assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0' +assert eq(b'123', b'123') +assert not eq(b'123', b'1234') +assert neq(b'123', b'1234') +try: + f('x') + assert False +except TypeError: + pass + +[case testBytesOps] +def test_indexing() -> None: + # Use bytes() to avoid constant folding + b = b'asdf' + bytes() + assert b[0] == 97 + assert b[1] == 115 + assert b[3] == 102 + assert b[-1] == 102 + b = b'\xfe\x15' + bytes() + assert b[0] == 254 + assert b[1] == 21 + +def test_concat() -> None: + b1 = b'123' + bytes() + b2 = b'456' + bytes() + assert b1 + b2 == b'123456' + +def test_join() -> None: + seq = (b'1', b'"', b'\xf0') + assert b'\x07'.join(seq) == b'1\x07"\x07\xf0' + assert b', '.join(()) == b'' + assert b', '.join([bytes() + b'ab']) == b'ab' + assert b', '.join([bytes() + b'ab', b'cd']) == b'ab, cd' + +def test_len() -> None: + # Use bytes() to avoid constant folding + b = b'foo' + bytes() + assert len(b) == 3 + assert len(bytes()) == 0 diff --git a/mypyc/test-data/run-primitives.test b/mypyc/test-data/run-primitives.test index 450480d3f0a6..37a92a15e0fc 100644 --- a/mypyc/test-data/run-primitives.test +++ b/mypyc/test-data/run-primitives.test @@ -255,32 +255,6 @@ assert str(to_int(3.14)) == '3' assert str(to_int(3)) == '3' assert get_complex() == 3+5j -[case testBytes] -def f(x: bytes) -> bytes: - return x - -def concat(a: bytes, b: bytes) -> bytes: - return a + b - -def eq(a: bytes, b: bytes) -> bool: - return a == b - -def neq(a: bytes, b: bytes) -> bool: - return a != b - -def join() -> bytes: - seq = (b'1', b'"', b'\xf0') - return b'\x07'.join(seq) -[file driver.py] -from native import f, concat, eq, neq, join -assert f(b'123') == b'123' -assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0' -assert concat(b'123', b'456') == b'123456' -assert eq(b'123', b'123') -assert not eq(b'123', b'1234') -assert neq(b'123', b'1234') -assert join() == b'1\x07"\x07\xf0' - [case testDel] from typing import List from testutil import assertRaises diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 82a288e0d293..a080339dae8e 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -35,6 +35,7 @@ 'run-integers.test', 'run-bools.test', 'run-strings.test', + 'run-bytes.test', 'run-tuples.test', 'run-lists.test', 'run-dicts.test',