Skip to content

Commit 97a1b3f

Browse files
authored
Add primitive for bytes join() method (#10929)
1 parent e734321 commit 97a1b3f

File tree

5 files changed

+61
-1
lines changed

5 files changed

+61
-1
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,8 @@ Py_ssize_t CPyStr_Size_size_t(PyObject *str);
400400
// Bytes operations
401401

402402

403+
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
404+
403405

404406
// Set operations
405407

mypyc/lib-rt/bytes_ops.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,13 @@
44

55
#include <Python.h>
66
#include "CPy.h"
7+
8+
// Like _PyBytes_Join but fallback to dynamic call if 'sep' is not bytes
9+
// (mostly commonly, for bytearrays)
10+
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter) {
11+
if (PyBytes_CheckExact(sep)) {
12+
return _PyBytes_Join(sep, iter);
13+
} else {
14+
return PyObject_CallMethod(sep, "join", "(O)", iter);
15+
}
16+
}

mypyc/primitives/bytes_ops.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive,
66
str_rprimitive, RUnion
77
)
8-
from mypyc.primitives.registry import load_address_op, function_op
8+
from mypyc.primitives.registry import load_address_op, function_op, method_op
99

1010

1111
# Get the 'bytes' type object.
@@ -29,3 +29,12 @@
2929
return_type=bytes_rprimitive,
3030
c_function_name='PyByteArray_FromObject',
3131
error_kind=ERR_MAGIC)
32+
33+
# bytes.join(obj)
34+
method_op(
35+
name='join',
36+
arg_types=[bytes_rprimitive, object_rprimitive],
37+
return_type=bytes_rprimitive,
38+
c_function_name='CPyBytes_Join',
39+
error_kind=ERR_MAGIC
40+
)

mypyc/test-data/irbuild-bytes.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,17 @@ L0:
6262
c = r6
6363
return 1
6464

65+
66+
[case testBytesJoin]
67+
from typing import List
68+
69+
def f(b: List[bytes]) -> bytes:
70+
return b" ".join(b)
71+
[out]
72+
def f(b):
73+
b :: list
74+
r0, r1 :: bytes
75+
L0:
76+
r0 = b' '
77+
r1 = CPyBytes_Join(r0, b)
78+
return r1

mypyc/test-data/run-bytes.test

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,28 @@ def test_bytearray_passed_into_bytes() -> None:
9797
assert f(bytearray(3))
9898
brr1: Any = bytearray()
9999
assert f(brr1)
100+
101+
[case testBytesJoin]
102+
from typing import Any
103+
from testutil import assertRaises
104+
from a import bytes_subclass
105+
106+
def test_bytes_join() -> None:
107+
assert b' '.join([b'a', b'b']) == b'a b'
108+
assert b' '.join([]) == b''
109+
110+
x: bytes = bytearray(b' ')
111+
assert x.join([b'a', b'b']) == b'a b'
112+
assert type(x.join([b'a', b'b'])) == bytearray
113+
114+
y: bytes = bytes_subclass()
115+
assert y.join([]) == b'spook'
116+
117+
n: Any = 5
118+
with assertRaises(TypeError, "can only join an iterable"):
119+
assert b' '.join(n)
120+
121+
[file a.py]
122+
class bytes_subclass(bytes):
123+
def join(self, iter):
124+
return b'spook'

0 commit comments

Comments
 (0)