diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index 0ef6dd57e441..323d34f83a04 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -451,8 +451,8 @@ def emit_cast(self, # TODO: Verify refcount handling. if (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ) - or is_str_rprimitive(typ) or is_bytes_rprimitive(typ) or is_range_rprimitive(typ) - or is_float_rprimitive(typ) or is_int_rprimitive(typ) or is_bool_rprimitive(typ)): + or is_str_rprimitive(typ) or is_range_rprimitive(typ) or is_float_rprimitive(typ) + or is_int_rprimitive(typ) or is_bool_rprimitive(typ) or is_bit_rprimitive(typ)): if declare_dest: self.emit_line('PyObject *{};'.format(dest)) if is_list_rprimitive(typ): @@ -463,8 +463,6 @@ def emit_cast(self, prefix = 'PySet' elif is_str_rprimitive(typ): prefix = 'PyUnicode' - elif is_bytes_rprimitive(typ): - prefix = 'PyBytes' elif is_range_rprimitive(typ): prefix = 'PyRange' elif is_float_rprimitive(typ): @@ -484,6 +482,18 @@ def emit_cast(self, 'else {', err, '}') + elif is_bytes_rprimitive(typ): + if declare_dest: + self.emit_line('PyObject *{};'.format(dest)) + check = '(PyBytes_Check({}) || PyByteArray_Check({}))' + if likely: + check = '(likely{})'.format(check) + self.emit_arg_check(src, dest, typ, check.format(src, src), optional) + self.emit_lines( + ' {} = {};'.format(dest, src), + 'else {', + err, + '}') elif is_tuple_rprimitive(typ): if declare_dest: self.emit_line('{} {};'.format(self.ctype(typ), dest)) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 9e32a70dfd79..0384f4831702 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -106,6 +106,14 @@ def __ne__(self, x: object) -> bool: pass def __getitem__(self, i: int) -> int: pass def join(self, x: Iterable[object]) -> bytes: pass +class bytearray: + @overload + def __init__(self) -> None: pass + @overload + def __init__(self, x: object) -> None: pass + @overload + def __init__(self, string: str, encoding: str, err: str = ...) -> None: pass + class bool(int): def __init__(self, o: object = ...) -> None: ... @overload diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index b88201fc896a..a1bf3de599d4 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -52,4 +52,30 @@ def test_len() -> None: # Use bytes() to avoid constant folding b = b'foo' + bytes() assert len(b) == 3 - assert len(bytes()) == 0 \ No newline at end of file + assert len(bytes()) == 0 + +[case testBytearrayBasics] +from typing import Any + +def test_init() -> None: + brr1: bytes = bytearray(3) + assert brr1 == bytearray(b'\x00\x00\x00') + assert brr1 == b'\x00\x00\x00' + l = [10, 20, 30, 40] + brr2: bytes = bytearray(l) + assert brr2 == bytearray(b'\n\x14\x1e(') + assert brr2 == b'\n\x14\x1e(' + brr3: bytes = bytearray(range(5)) + assert brr3 == bytearray(b'\x00\x01\x02\x03\x04') + assert brr3 == b'\x00\x01\x02\x03\x04' + brr4: bytes = bytearray('string', 'utf-8') + assert brr4 == bytearray(b'string') + assert brr4 == b'string' + +def f(b: bytes) -> bool: + return True + +def test_bytearray_passed_into_bytes() -> None: + assert f(bytearray(3)) + brr1: Any = bytearray() + assert f(brr1)