Skip to content

Commit 08a45e0

Browse files
authored
[mypyc] Add a primitive for bytes "translate" method (#20305)
This only supports one-argument calls. In a microbenchmark, the performance was significantly better for small inputs. It was also slightly better for long inputs, but only after I added a loop-unrolling hint (tested on Apple Silicon and on AMD + Linux). I relied on LLM assistance as another experiment, but worked in small increments with detailed instructions.
1 parent 03b12a1 commit 08a45e0

File tree

7 files changed

+115
-0
lines changed

7 files changed

+115
-0
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,7 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
783783
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
784784
CPyTagged CPyBytes_Ord(PyObject *obj);
785785
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count);
786+
PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table);
786787

787788

788789
int CPyBytes_Compare(PyObject *left, PyObject *right);

mypyc/lib-rt/bytes_ops.c

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,52 @@ PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) {
171171
}
172172
return PySequence_Repeat(bytes, temp_count);
173173
}
174+
175+
// Implements the one-argument form of bytes.translate(table).
//
// Returns a new reference on success and NULL on failure. When both
// arguments are exact bytes objects the translation is done inline;
// anything else is delegated to the object's own "translate" method.
PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table) {
    if (!PyBytes_CheckExact(bytes) || !PyBytes_CheckExact(table)) {
        // Slow path: bytes subclasses or non-bytes tables go through a
        // regular Python-level method call.
        _Py_IDENTIFIER(translate);
        PyObject *method_name = _PyUnicode_FromId(&PyId_translate);
        if (method_name == NULL) {
            return NULL;
        }
        return PyObject_CallMethodOneArg(bytes, method_name, table);
    }

    // Fast path from here on: exact bytes input and exact bytes table.
    if (PyBytes_GET_SIZE(table) != 256) {
        PyErr_SetString(PyExc_ValueError,
                        "translation table must be 256 characters long");
        return NULL;
    }

    Py_ssize_t size = PyBytes_GET_SIZE(bytes);
    PyObject *translated = PyBytes_FromStringAndSize(NULL, size);
    if (translated == NULL) {
        return NULL;
    }

    const char *src = PyBytes_AS_STRING(bytes);
    const char *lookup = PyBytes_AS_STRING(table);
    char *dst = PyBytes_AS_STRING(translated);
    bool modified = false;

    // Without a loop unrolling hint performance can be worse than CPython
    CPY_UNROLL_LOOP(4)
    for (Py_ssize_t remaining = size; --remaining >= 0;) {
        char ch = *src++;
        // Index with an unsigned value so bytes >= 0x80 don't become
        // negative offsets into the table.
        if ((*dst++ = lookup[(unsigned char)ch]) != ch)
            modified = true;
    }

    if (modified) {
        return translated;
    }

    // Every byte mapped to itself: drop the copy and hand back the
    // original object with a fresh reference.
    Py_DECREF(translated);
    Py_INCREF(bytes);
    return bytes;
}

mypyc/lib-rt/mypyc_util.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,19 @@
4848

4949
#endif // Py_GIL_DISABLED
5050

51+
// Turns its argument into a string literal; _Pragma only accepts a string.
#define CPY_STRINGIFY(x) #x

// CPY_UNROLL_LOOP(n): put this directly before a loop to hint that the
// compiler should unroll it with factor n. On compilers without a known
// unroll pragma it expands to nothing.
#if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 8)
#define CPY_UNROLL_LOOP_IMPL(x) _Pragma(CPY_STRINGIFY(x))
#endif

#if defined(__clang__)
#define CPY_UNROLL_LOOP(n) CPY_UNROLL_LOOP_IMPL(unroll n)
#elif defined(__GNUC__) && __GNUC__ >= 8
#define CPY_UNROLL_LOOP(n) CPY_UNROLL_LOOP_IMPL(GCC unroll n)
#else
#define CPY_UNROLL_LOOP(n)
#endif
63+
5164
// INCREF and DECREF that assert the pointer is not NULL.
5265
// asserts are disabled in release builds so there shouldn't be a perf hit.
5366
// I'm honestly kind of surprised that this isn't done by default.

mypyc/primitives/bytes_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@
128128
error_kind=ERR_MAGIC,
129129
)
130130

131+
# bytes.translate(table), one-argument form only. The receiver must be
# bytes; the table is accepted as an arbitrary object.
method_op(
    name="translate",
    c_function_name="CPyBytes_Translate",
    arg_types=[bytes_rprimitive, object_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)
139+
131140
# Join bytes objects and return a new bytes.
132141
# The first argument is the total number of the following bytes.
133142
bytes_build_op = custom_op(

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def __getitem__(self, i: int) -> int: ...
178178
def __getitem__(self, i: slice) -> bytes: ...
179179
def join(self, x: Iterable[object]) -> bytes: ...
180180
def decode(self, encoding: str=..., errors: str=...) -> str: ...
181+
def translate(self, t: bytes) -> bytes: ...
181182
def __iter__(self) -> Iterator[int]: ...
182183

183184
class bytearray:

mypyc/test-data/irbuild-bytes.test

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,13 @@ def i_times_b(s, n):
238238
L0:
239239
r0 = CPyBytes_Multiply(s, n)
240240
return r0
241+
242+
[case testBytesTranslate]
243+
def f(b: bytes, table: bytes) -> bytes:
244+
return b.translate(table)
245+
[out]
246+
def f(b, table):
247+
b, table, r0 :: bytes
248+
L0:
249+
r0 = CPyBytes_Translate(b, table)
250+
return r0

mypyc/test-data/run-bytes.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,38 @@ def test_multiply() -> None:
168168
result = b * two
169169
assert type(result) == bytes
170170

171+
def test_translate() -> None:
    # Fast path (exact bytes table): a table mapping every byte to itself
    identity = bytes(range(256))
    b = b'hello world' + bytes()
    assert b.translate(identity) == b'hello world'

    # Fast path with a table that rotates lowercase letters by 13 places
    rot13 = bytearray(range(256))
    for offset in range(26):
        rot13[ord('a') + offset] = ord('a') + (offset + 13) % 26
    rot13_table = bytes(rot13)
    assert b'hello'.translate(rot13_table) == b'uryyb'
    assert (b'abc' + bytes()).translate(rot13_table) == b'nop'

    # Bytes at the extremes of the value range
    assert b'\x00\x01\xff'.translate(identity) == b'\x00\x01\xff'

    # Slow path: a bytearray table falls back to the Python method
    upper_table = bytearray(range(256))
    assert b'hello'.translate(upper_table) == b'hello'
    # Now make the same table map lowercase letters to uppercase
    for offset in range(26):
        upper_table[ord('a') + offset] = ord('A') + offset
    assert b'hello world'.translate(upper_table) == b'HELLO WORLD'
    assert (b'test' + bytes()).translate(upper_table) == b'TEST'

    # Tables that are not exactly 256 bytes long must be rejected
    with assertRaises(ValueError, "translation table must be 256 characters long"):
        b'test'.translate(b'short')
    with assertRaises(ValueError, "translation table must be 256 characters long"):
        b'test'.translate(bytes(100))
202+
171203
[case testBytesSlicing]
172204
def test_bytes_slicing() -> None:
173205
b = b'abcdefg'

0 commit comments

Comments
 (0)