Skip to content

Commit 702e2b3

Browse files
committed
Add an add ufunc
1 parent 718fea9 commit 702e2b3

File tree

2 files changed

+77
-10
lines changed

2 files changed

+77
-10
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,59 @@
1313
#include "string.h"
1414
#include "umath.h"
1515

16+
static int
17+
add_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
18+
char *const data[], npy_intp const dimensions[],
19+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
20+
{
21+
npy_intp N = dimensions[0];
22+
char *in1 = data[0];
23+
char *in2 = data[1];
24+
char *out = data[2];
25+
npy_intp in1_stride = strides[0];
26+
npy_intp in2_stride = strides[1];
27+
npy_intp out_stride = strides[2];
28+
29+
ss *s1 = NULL, *s2 = NULL, *os = NULL;
30+
int newlen = 0;
31+
32+
while (N--) {
33+
s1 = (ss *)in1;
34+
s2 = (ss *)in2;
35+
os = (ss *)out;
36+
newlen = s1->len + s2->len;
37+
38+
ssfree(os);
39+
if (ssnewemptylen(newlen, os) < 0) {
40+
return -1;
41+
}
42+
43+
memcpy(os->buf, s1->buf, s1->len);
44+
memcpy(os->buf + s1->len, s2->buf, s2->len);
45+
os->buf[newlen] = '\0';
46+
47+
in1 += in1_stride;
48+
in2 += in2_stride;
49+
out += out_stride;
50+
}
51+
return 0;
52+
}
53+
1654
static int
1755
minmax_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
1856
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
1957
PyArray_Descr *given_descrs[],
2058
PyArray_Descr *loop_descrs[],
2159
npy_intp *NPY_UNUSED(view_offset))
2260
{
23-
Py_INCREF(given_descrs[0]);
24-
loop_descrs[0] = given_descrs[0];
25-
Py_INCREF(given_descrs[1]);
26-
loop_descrs[1] = given_descrs[1];
27-
2861
StringDTypeObject *new = new_stringdtype_instance();
2962
if (new == NULL) {
3063
return -1;
3164
}
65+
Py_INCREF(given_descrs[0]);
66+
loop_descrs[0] = given_descrs[0];
67+
Py_INCREF(given_descrs[1]);
68+
loop_descrs[1] = given_descrs[1];
3269
loop_descrs[2] = (PyArray_Descr *)new;
3370

3471
return NPY_NO_CASTING;
@@ -270,20 +307,26 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
270307
return -1;
271308
}
272309

273-
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func},
274-
{NPY_METH_strided_loop, loop_func},
275-
{0, NULL}};
276-
277310
PyArrayMethod_Spec spec = {
278311
.name = loop_name,
279312
.nin = nin,
280313
.nout = nout,
281314
.casting = casting,
282315
.flags = flags,
283316
.dtypes = dtypes,
284-
.slots = slots,
285317
};
286318

319+
if (resolve_func == NULL) {
320+
PyType_Slot slots[] = {{NPY_METH_strided_loop, loop_func}, {0, NULL}};
321+
spec.slots = slots;
322+
}
323+
else {
324+
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func},
325+
{NPY_METH_strided_loop, loop_func},
326+
{0, NULL}};
327+
spec.slots = slots;
328+
}
329+
287330
if (PyUFunc_AddLoopFromSpec(ufunc, &spec) < 0) {
288331
Py_DECREF(ufunc);
289332
return -1;
@@ -378,6 +421,13 @@ init_ufuncs(void)
378421
goto error;
379422
}
380423

424+
PyArray_DTypeMeta *add_types[] = {&StringDType, &StringDType,
425+
&StringDType};
426+
if (init_ufunc(numpy, "add", add_types, NULL, &add_strided_loop,
427+
"string_add", 2, 1, NPY_NO_CASTING, 0) < 0) {
428+
goto error;
429+
}
430+
381431
Py_DECREF(numpy);
382432
return 0;
383433

stringdtype/tests/test_stringdtype.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,20 @@ def test_ufuncs_minmax(string_list, ufunc, func):
299299
np.testing.assert_array_equal(
300300
getattr(arr, ufunc)(), np.array(func(string_list), dtype=StringDType())
301301
)
302+
303+
304+
@pytest.mark.parametrize(
305+
"other_strings",
306+
[
307+
["abc", "def", "ghi", "🤣", "📵", "😰"],
308+
["🚜", "🙃", "😾", "😹", "🚠", "🚌"],
309+
["🥦", "¨", "⨯", "∰ ", "⨌ ", "⎶ "],
310+
],
311+
)
312+
def test_ufunc_add(string_list, other_strings):
313+
arr1 = np.array(string_list, dtype=StringDType())
314+
arr2 = np.array(other_strings, dtype=StringDType())
315+
np.testing.assert_array_equal(
316+
np.add(arr1, arr2),
317+
np.array([a + b for a, b in zip(arr1, arr2)], dtype=StringDType()),
318+
)

0 commit comments

Comments
 (0)