Skip to content

[mypyc] Add primitive for bytes join() method #10929

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ Py_ssize_t CPyStr_Size_size_t(PyObject *str);
// Bytes operations


PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);


// Set operations

Expand Down
10 changes: 10 additions & 0 deletions mypyc/lib-rt/bytes_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,13 @@

#include <Python.h>
#include "CPy.h"

// Like _PyBytes_Join but fallback to dynamic call if 'sep' is not bytes
// (mostly commonly, for bytearrays)
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter) {
if (PyBytes_CheckExact(sep)) {
return _PyBytes_Join(sep, iter);
} else {
return PyObject_CallMethod(sep, "join", "(O)", iter);
}
}
11 changes: 10 additions & 1 deletion mypyc/primitives/bytes_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive,
str_rprimitive, RUnion
)
from mypyc.primitives.registry import load_address_op, function_op
from mypyc.primitives.registry import load_address_op, function_op, method_op


# Get the 'bytes' type object.
Expand All @@ -29,3 +29,12 @@
return_type=bytes_rprimitive,
c_function_name='PyByteArray_FromObject',
error_kind=ERR_MAGIC)

# bytes.join(obj)
method_op(
name='join',
arg_types=[bytes_rprimitive, object_rprimitive],
return_type=bytes_rprimitive,
c_function_name='CPyBytes_Join',
error_kind=ERR_MAGIC
)
14 changes: 14 additions & 0 deletions mypyc/test-data/irbuild-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,17 @@ L0:
c = r6
return 1


[case testBytesJoin]
from typing import List

def f(b: List[bytes]) -> bytes:
return b" ".join(b)
[out]
def f(b):
b :: list
r0, r1 :: bytes
L0:
r0 = b' '
r1 = CPyBytes_Join(r0, b)
return r1
25 changes: 25 additions & 0 deletions mypyc/test-data/run-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,28 @@ def test_bytearray_passed_into_bytes() -> None:
assert f(bytearray(3))
brr1: Any = bytearray()
assert f(brr1)

[case testBytesJoin]
from typing import Any
from testutil import assertRaises
from a import bytes_subclass

def test_bytes_join() -> None:
assert b' '.join([b'a', b'b']) == b'a b'
assert b' '.join([]) == b''

x: bytes = bytearray(b' ')
assert x.join([b'a', b'b']) == b'a b'
assert type(x.join([b'a', b'b'])) == bytearray

y: bytes = bytes_subclass()
assert y.join([]) == b'spook'

n: Any = 5
with assertRaises(TypeError, "can only join an iterable"):
assert b' '.join(n)

[file a.py]
class bytes_subclass(bytes):
def join(self, iter):
return b'spook'