diff --git a/stringdtype/stringdtype/src/umath.c b/stringdtype/stringdtype/src/umath.c index 5fa15064..30330962 100644 --- a/stringdtype/stringdtype/src/umath.c +++ b/stringdtype/stringdtype/src/umath.c @@ -14,24 +14,41 @@ #include "umath.h" static int -minmax_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method), - PyArray_DTypeMeta *NPY_UNUSED(dtypes[]), - PyArray_Descr *given_descrs[], - PyArray_Descr *loop_descrs[], - npy_intp *NPY_UNUSED(view_offset)) +add_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context), + char *const data[], npy_intp const dimensions[], + npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata)) { - Py_INCREF(given_descrs[0]); - loop_descrs[0] = given_descrs[0]; - Py_INCREF(given_descrs[1]); - loop_descrs[1] = given_descrs[1]; + npy_intp N = dimensions[0]; + char *in1 = data[0]; + char *in2 = data[1]; + char *out = data[2]; + npy_intp in1_stride = strides[0]; + npy_intp in2_stride = strides[1]; + npy_intp out_stride = strides[2]; - StringDTypeObject *new = new_stringdtype_instance(); - if (new == NULL) { - return -1; - } - loop_descrs[2] = (PyArray_Descr *)new; + ss *s1 = NULL, *s2 = NULL, *os = NULL; + int newlen = 0; - return NPY_NO_CASTING; + while (N--) { + s1 = (ss *)in1; + s2 = (ss *)in2; + os = (ss *)out; + newlen = s1->len + s2->len; + + ssfree(os); + if (ssnewemptylen(newlen, os) < 0) { + return -1; + } + + memcpy(os->buf, s1->buf, s1->len); + memcpy(os->buf + s1->len, s2->buf, s2->len); + os->buf[newlen] = '\0'; + + in1 += in1_stride; + in2 += in2_stride; + out += out_stride; + } + return 0; } static int @@ -259,6 +276,9 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[], return 0; } +// Register a ufunc. +// +// Pass NULL for resolve_func to use the default_resolve_descriptors. int init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes, resolve_descriptors_function *resolve_func, @@ -270,10 +290,6 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes, return -1; } - PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func}, - {NPY_METH_strided_loop, loop_func}, - {0, NULL}}; - PyArrayMethod_Spec spec = { .name = loop_name, .nin = nin, @@ -281,9 +297,19 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes, .casting = casting, .flags = flags, .dtypes = dtypes, - .slots = slots, }; + if (resolve_func == NULL) { + PyType_Slot slots[] = {{NPY_METH_strided_loop, loop_func}, {0, NULL}}; + spec.slots = slots; + } + else { + PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func}, + {NPY_METH_strided_loop, loop_func}, + {0, NULL}}; + spec.slots = slots; + } + if (PyUFunc_AddLoopFromSpec(ufunc, &spec) < 0) { Py_DECREF(ufunc); return -1; @@ -367,14 +393,21 @@ init_ufuncs(void) PyArray_DTypeMeta *minmax_dtypes[] = {&StringDType, &StringDType, &StringDType}; - if (init_ufunc(numpy, "maximum", minmax_dtypes, - &minmax_resolve_descriptors, &maximum_strided_loop, - "string_maximum", 2, 1, NPY_NO_CASTING, 0) < 0) { + if (init_ufunc(numpy, "maximum", minmax_dtypes, NULL, + &maximum_strided_loop, "string_maximum", 2, 1, + NPY_NO_CASTING, 0) < 0) { goto error; } - if (init_ufunc(numpy, "minimum", minmax_dtypes, - &minmax_resolve_descriptors, &minimum_strided_loop, - "string_minimum", 2, 1, NPY_NO_CASTING, 0) < 0) { + if (init_ufunc(numpy, "minimum", minmax_dtypes, NULL, + &minimum_strided_loop, "string_minimum", 2, 1, + NPY_NO_CASTING, 0) < 0) { + goto error; + } + + PyArray_DTypeMeta *add_types[] = {&StringDType, &StringDType, + &StringDType}; + if (init_ufunc(numpy, "add", add_types, NULL, &add_strided_loop, + "string_add", 2, 1, NPY_NO_CASTING, 0) < 0) { goto error; } diff --git a/stringdtype/tests/test_stringdtype.py b/stringdtype/tests/test_stringdtype.py index 3ab513f0..136320bc 100644 --- a/stringdtype/tests/test_stringdtype.py +++ b/stringdtype/tests/test_stringdtype.py @@ -299,3 +299,20 @@ def test_ufuncs_minmax(string_list, ufunc, func): np.testing.assert_array_equal( getattr(arr, ufunc)(), np.array(func(string_list), dtype=StringDType()) ) + + +@pytest.mark.parametrize( + "other_strings", + [ + ["abc", "def", "ghi", "๐Ÿคฃ", "๐Ÿ“ต", "๐Ÿ˜ฐ"], + ["๐Ÿšœ", "๐Ÿ™ƒ", "๐Ÿ˜พ", "๐Ÿ˜น", "๐Ÿš ", "๐ŸšŒ"], + ["๐Ÿฅฆ", "ยจ", "โจฏ", "โˆฐ ", "โจŒ ", "โŽถ "], + ], +) +def test_ufunc_add(string_list, other_strings): + arr1 = np.array(string_list, dtype=StringDType()) + arr2 = np.array(other_strings, dtype=StringDType()) + np.testing.assert_array_equal( + np.add(arr1, arr2), + np.array([a + b for a, b in zip(arr1, arr2)], dtype=StringDType()), + )