Skip to content

Commit 26c270c

Browse files
authored
Merge pull request #57 from peytondmurray/ufunc-add
Add an add ufunc
2 parents d98210c + 19c3f87 commit 26c270c

File tree

2 files changed

+76
-26
lines changed

2 files changed

+76
-26
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,41 @@
1414
#include "umath.h"
1515

1616
static int
17-
minmax_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
18-
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
19-
PyArray_Descr *given_descrs[],
20-
PyArray_Descr *loop_descrs[],
21-
npy_intp *NPY_UNUSED(view_offset))
17+
add_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
18+
char *const data[], npy_intp const dimensions[],
19+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
2220
{
23-
Py_INCREF(given_descrs[0]);
24-
loop_descrs[0] = given_descrs[0];
25-
Py_INCREF(given_descrs[1]);
26-
loop_descrs[1] = given_descrs[1];
21+
npy_intp N = dimensions[0];
22+
char *in1 = data[0];
23+
char *in2 = data[1];
24+
char *out = data[2];
25+
npy_intp in1_stride = strides[0];
26+
npy_intp in2_stride = strides[1];
27+
npy_intp out_stride = strides[2];
2728

28-
StringDTypeObject *new = new_stringdtype_instance();
29-
if (new == NULL) {
30-
return -1;
31-
}
32-
loop_descrs[2] = (PyArray_Descr *)new;
29+
ss *s1 = NULL, *s2 = NULL, *os = NULL;
30+
int newlen = 0;
3331

34-
return NPY_NO_CASTING;
32+
while (N--) {
33+
s1 = (ss *)in1;
34+
s2 = (ss *)in2;
35+
os = (ss *)out;
36+
newlen = s1->len + s2->len;
37+
38+
ssfree(os);
39+
if (ssnewemptylen(newlen, os) < 0) {
40+
return -1;
41+
}
42+
43+
memcpy(os->buf, s1->buf, s1->len);
44+
memcpy(os->buf + s1->len, s2->buf, s2->len);
45+
os->buf[newlen] = '\0';
46+
47+
in1 += in1_stride;
48+
in2 += in2_stride;
49+
out += out_stride;
50+
}
51+
return 0;
3552
}
3653

3754
static int
@@ -259,6 +276,9 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
259276
return 0;
260277
}
261278

279+
// Register a ufunc.
280+
//
281+
// Pass NULL for resolve_func to use the default_resolve_descriptors.
262282
int
263283
init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
264284
resolve_descriptors_function *resolve_func,
@@ -270,20 +290,26 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
270290
return -1;
271291
}
272292

273-
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func},
274-
{NPY_METH_strided_loop, loop_func},
275-
{0, NULL}};
276-
277293
PyArrayMethod_Spec spec = {
278294
.name = loop_name,
279295
.nin = nin,
280296
.nout = nout,
281297
.casting = casting,
282298
.flags = flags,
283299
.dtypes = dtypes,
284-
.slots = slots,
285300
};
286301

302+
if (resolve_func == NULL) {
303+
PyType_Slot slots[] = {{NPY_METH_strided_loop, loop_func}, {0, NULL}};
304+
spec.slots = slots;
305+
}
306+
else {
307+
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func},
308+
{NPY_METH_strided_loop, loop_func},
309+
{0, NULL}};
310+
spec.slots = slots;
311+
}
312+
287313
if (PyUFunc_AddLoopFromSpec(ufunc, &spec) < 0) {
288314
Py_DECREF(ufunc);
289315
return -1;
@@ -367,14 +393,21 @@ init_ufuncs(void)
367393

368394
PyArray_DTypeMeta *minmax_dtypes[] = {&StringDType, &StringDType,
369395
&StringDType};
370-
if (init_ufunc(numpy, "maximum", minmax_dtypes,
371-
&minmax_resolve_descriptors, &maximum_strided_loop,
372-
"string_maximum", 2, 1, NPY_NO_CASTING, 0) < 0) {
396+
if (init_ufunc(numpy, "maximum", minmax_dtypes, NULL,
397+
&maximum_strided_loop, "string_maximum", 2, 1,
398+
NPY_NO_CASTING, 0) < 0) {
373399
goto error;
374400
}
375-
if (init_ufunc(numpy, "minimum", minmax_dtypes,
376-
&minmax_resolve_descriptors, &minimum_strided_loop,
377-
"string_minimum", 2, 1, NPY_NO_CASTING, 0) < 0) {
401+
if (init_ufunc(numpy, "minimum", minmax_dtypes, NULL,
402+
&minimum_strided_loop, "string_minimum", 2, 1,
403+
NPY_NO_CASTING, 0) < 0) {
404+
goto error;
405+
}
406+
407+
PyArray_DTypeMeta *add_types[] = {&StringDType, &StringDType,
408+
&StringDType};
409+
if (init_ufunc(numpy, "add", add_types, NULL, &add_strided_loop,
410+
"string_add", 2, 1, NPY_NO_CASTING, 0) < 0) {
378411
goto error;
379412
}
380413

stringdtype/tests/test_stringdtype.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,20 @@ def test_ufuncs_minmax(string_list, ufunc, func):
299299
np.testing.assert_array_equal(
300300
getattr(arr, ufunc)(), np.array(func(string_list), dtype=StringDType())
301301
)
302+
303+
304+
@pytest.mark.parametrize(
305+
"other_strings",
306+
[
307+
["abc", "def", "ghi", "🤣", "📵", "😰"],
308+
["🚜", "🙃", "😾", "😹", "🚠", "🚌"],
309+
["🥦", "¨", "⨯", "∰ ", "⨌ ", "⎶ "],
310+
],
311+
)
312+
def test_ufunc_add(string_list, other_strings):
313+
arr1 = np.array(string_list, dtype=StringDType())
314+
arr2 = np.array(other_strings, dtype=StringDType())
315+
np.testing.assert_array_equal(
316+
np.add(arr1, arr2),
317+
np.array([a + b for a, b in zip(arr1, arr2)], dtype=StringDType()),
318+
)

0 commit comments

Comments
 (0)