diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4d59407c..30c99cd4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: language: system require_serial: true entry: | - bash -c 'cd metadatadtype && mkdir -p build && pip install build meson-python patchelf wheel && python -m build --wheel --no-isolation -Cbuilddir=build'; + bash -c 'cd metadatadtype && mkdir -p build && pip install build meson-python patchelf wheel && meson setup build && python -m build --wheel --no-isolation -Cbuilddir=build'; fail_fast: false - id: generate-compilation-database-asciidtype name: Generate compilation database [asciidtype] @@ -15,7 +15,7 @@ repos: language: system require_serial: true entry: | - bash -c 'cd asciidtype && mkdir -p build && pip install build meson-python patchelf wheel && python -m build --wheel --no-isolation -Cbuilddir=build'; + bash -c 'cd asciidtype && mkdir -p build && pip install build meson-python patchelf wheel && meson setup build && python -m build --wheel --no-isolation -Cbuilddir=build'; fail_fast: false - id: generate-compilation-database-quaddtype name: Generate compilation database [quaddtype] @@ -23,7 +23,7 @@ repos: language: system require_serial: true entry: | - bash -c 'cd quaddtype && mkdir -p build && pip install build meson-python patchelf wheel && python -m build --wheel --no-isolation -Cbuilddir=build'; + bash -c 'cd quaddtype && mkdir -p build && pip install build meson-python patchelf wheel && meson setup build && python -m build --wheel --no-isolation -Cbuilddir=build'; fail_fast: false - id: generate-compilation-database-unytdtype name: Generate compilation database [unytdtype] @@ -31,7 +31,7 @@ repos: language: system require_serial: true entry: | - bash -c 'cd unytdtype && mkdir -p build && pip install build meson-python patchelf wheel && python -m build --wheel --no-isolation -Cbuilddir=build'; + bash -c 'cd unytdtype && mkdir -p build && pip install build meson-python patchelf 
wheel && meson setup build && python -m build --wheel --no-isolation -Cbuilddir=build'; fail_fast: false - id: generate-compilation-database-stringdtype name: Generate compilation database [stringdtype] @@ -39,7 +39,7 @@ repos: language: system require_serial: true entry: | - bash -c 'cd stringdtype && mkdir -p build && pip install build meson-python patchelf wheel && python -m build --wheel --no-isolation -Cbuilddir=build'; + bash -c 'cd stringdtype && mkdir -p build && pip install build meson-python patchelf wheel && meson setup build && python -m build --wheel --no-isolation -Cbuilddir=build'; fail_fast: false - repo: https://github.com/pocc/pre-commit-hooks rev: v1.3.5 diff --git a/asciidtype/asciidtype/src/asciidtype_main.c b/asciidtype/asciidtype/src/asciidtype_main.c index 1574eb91..b341e859 100644 --- a/asciidtype/asciidtype/src/asciidtype_main.c +++ b/asciidtype/asciidtype/src/asciidtype_main.c @@ -22,7 +22,7 @@ PyInit__asciidtype_main(void) return NULL; } - if (import_experimental_dtype_api(9) < 0) { + if (import_experimental_dtype_api(10) < 0) { return NULL; } diff --git a/asciidtype/meson.build b/asciidtype/meson.build index 2b566186..762b2b5d 100644 --- a/asciidtype/meson.build +++ b/asciidtype/meson.build @@ -35,7 +35,8 @@ py.install_sources( 'asciidtype/__init__.py', 'asciidtype/scalar.py' ], - subdir: 'asciidtype' + subdir: 'asciidtype', + pure: false ) py.extension_module( diff --git a/metadatadtype/meson.build b/metadatadtype/meson.build index e1165293..ff136af8 100644 --- a/metadatadtype/meson.build +++ b/metadatadtype/meson.build @@ -35,7 +35,8 @@ py.install_sources( 'metadatadtype/__init__.py', 'metadatadtype/scalar.py' ], - subdir: 'metadatadtype' + subdir: 'metadatadtype', + pure: false ) py.extension_module( diff --git a/metadatadtype/metadatadtype/src/metadatadtype_main.c b/metadatadtype/metadatadtype/src/metadatadtype_main.c index 2681af6c..e432accb 100644 --- a/metadatadtype/metadatadtype/src/metadatadtype_main.c +++ 
b/metadatadtype/metadatadtype/src/metadatadtype_main.c @@ -21,7 +21,7 @@ PyInit__metadatadtype_main(void) if (_import_array() < 0) { return NULL; } - if (import_experimental_dtype_api(9) < 0) { + if (import_experimental_dtype_api(10) < 0) { return NULL; } diff --git a/mpfdtype/meson.build b/mpfdtype/meson.build index f67f4d48..729f3553 100644 --- a/mpfdtype/meson.build +++ b/mpfdtype/meson.build @@ -46,7 +46,8 @@ py.install_sources( [ 'mpfdtype/__init__.py', ], - subdir: 'mpfdtype' + subdir: 'mpfdtype', + pure: false ) py.extension_module( diff --git a/mpfdtype/mpfdtype/src/mpfdtype_main.c b/mpfdtype/mpfdtype/src/mpfdtype_main.c index 0f68b586..94ae37d7 100644 --- a/mpfdtype/mpfdtype/src/mpfdtype_main.c +++ b/mpfdtype/mpfdtype/src/mpfdtype_main.c @@ -22,7 +22,7 @@ PyInit__mpfdtype_main(void) if (_import_array() < 0) { return NULL; } - if (import_experimental_dtype_api(9) < 0) { + if (import_experimental_dtype_api(10) < 0) { return NULL; } diff --git a/quaddtype/meson.build b/quaddtype/meson.build index 44bc889e..3c1b0712 100644 --- a/quaddtype/meson.build +++ b/quaddtype/meson.build @@ -33,7 +33,8 @@ py.install_sources( 'quaddtype/__init__.py', 'quaddtype/quadscalar.py' ], - subdir: 'quaddtype' + subdir: 'quaddtype', + pure: false ) py.extension_module( diff --git a/quaddtype/quaddtype/src/quaddtype_main.c b/quaddtype/quaddtype/src/quaddtype_main.c index 5f13c079..e7420946 100644 --- a/quaddtype/quaddtype/src/quaddtype_main.c +++ b/quaddtype/quaddtype/src/quaddtype_main.c @@ -23,7 +23,7 @@ PyInit__quaddtype_main(void) return NULL; // Fail to init if the experimental DType API version 5 isn't supported - if (import_experimental_dtype_api(9) < 0) { + if (import_experimental_dtype_api(10) < 0) { PyErr_SetString(PyExc_ImportError, "Error encountered importing the experimental dtype API."); return NULL; diff --git a/stringdtype/meson.build b/stringdtype/meson.build index 64b6a114..71317790 100644 --- a/stringdtype/meson.build +++ b/stringdtype/meson.build @@ -35,9 
+35,11 @@ srcs = [ py.install_sources( [ 'stringdtype/__init__.py', - 'stringdtype/scalar.py' + 'stringdtype/scalar.py', + 'stringdtype/missing.py', ], - subdir: 'stringdtype' + subdir: 'stringdtype', + pure: false ) py.extension_module( diff --git a/stringdtype/pyproject.toml b/stringdtype/pyproject.toml index 1fcf0803..3ad86297 100644 --- a/stringdtype/pyproject.toml +++ b/stringdtype/pyproject.toml @@ -31,6 +31,6 @@ per-file-ignores = {"__init__.py" = ["F401"]} [tool.meson-python.args] dist = [] -setup = ["-Ddebug=true", "-Doptimization=2"] +setup = ["-Ddebug=true", "-Doptimization=0"] compile = [] install = [] diff --git a/stringdtype/stringdtype/__init__.py b/stringdtype/stringdtype/__init__.py index 6fc58f31..f79e5ac0 100644 --- a/stringdtype/stringdtype/__init__.py +++ b/stringdtype/stringdtype/__init__.py @@ -2,10 +2,12 @@ """ +from .missing import NA # isort: skip from .scalar import StringScalar # isort: skip from ._main import StringDType, _memory_usage __all__ = [ + "NA", "StringDType", "StringScalar", "_memory_usage", diff --git a/stringdtype/stringdtype/missing.py b/stringdtype/stringdtype/missing.py new file mode 100644 index 00000000..69a547d3 --- /dev/null +++ b/stringdtype/stringdtype/missing.py @@ -0,0 +1,6 @@ +class NAType: + def __repr__(self): + return "stringdtype.NA" + + +NA = NAType() diff --git a/stringdtype/stringdtype/src/casts.c b/stringdtype/stringdtype/src/casts.c index 374948c2..dd1b8282 100644 --- a/stringdtype/stringdtype/src/casts.c +++ b/stringdtype/stringdtype/src/casts.c @@ -53,15 +53,13 @@ string_to_string(PyArrayMethod_Context *NPY_UNUSED(context), npy_intp out_stride = strides[1]; ss *s = NULL; + ss *os = NULL; while (N--) { - // *out* may be reallocated later; *in->buf* may point to a statically - // allocated empty ss struct, so we need to load the string into an - // intermediate buffer *s* to avoid the possibility of freeing static - // data later on. 
- load_string(in, (ss **)&s); - ssfree((ss *)out); - if (ssdup((ss *)s, (ss *)out) < 0) { + s = (ss *)in; + os = (ss *)out; + ssfree(os); + if (ssdup(s, os) < 0) { gil_error(PyExc_MemoryError, "ssdup failed"); return -1; } @@ -338,9 +336,18 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[], ss *s = NULL; while (N--) { - load_string(in, &s); - unsigned char *this_string = (unsigned char *)(s->buf); - size_t n_bytes = s->len; + s = (ss *)in; + unsigned char *this_string = NULL; + size_t n_bytes; + if (ss_isnull(s)) { + // lossy but not much else we can do; n_bytes excludes the NUL + this_string = (unsigned char *)"NA"; + n_bytes = 2; + } + else { + this_string = (unsigned char *)(s->buf); + n_bytes = s->len; + } size_t tot_n_bytes = 0; for (int i = 0; i < max_out_size; i++) { @@ -401,7 +408,7 @@ string_to_bool_resolve_descriptors(PyObject *NPY_UNUSED(self), } static int -string_to_bool(PyArrayMethod_Context *context, char *const data[], +string_to_bool(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[], npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata)) { @@ -415,8 +422,12 @@ string_to_bool(PyArrayMethod_Context *context, char *const data[], ss *s = NULL; while (N--) { - load_string(in, &s); - if (s->len == 0) { + s = (ss *)in; + if (ss_isnull(s)) { + // numpy treats NaN as truthy, following python + *out = (npy_bool)1; + } + else if (s->len == 0) { *out = (npy_bool)0; } else { diff --git a/stringdtype/stringdtype/src/dtype.c b/stringdtype/stringdtype/src/dtype.c index c188d774..f73e1296 100644 --- a/stringdtype/stringdtype/src/dtype.c +++ b/stringdtype/stringdtype/src/dtype.c @@ -4,6 +4,8 @@ #include "static_string.h" PyTypeObject *StringScalar_Type = NULL; +static PyTypeObject *StringNA_Type = NULL; +static PyObject *NA_OBJ = NULL; /* * Internal helper to create new instances @@ -80,14 +82,35 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls), static PyObject * get_value(PyObject *scalar) { - 
PyObject *ret_bytes = NULL; + PyObject *ret = NULL; PyTypeObject *scalar_type = Py_TYPE(scalar); // FIXME: handle bytes too if ((scalar_type == &PyUnicode_Type) || (scalar_type == StringScalar_Type)) { // attempt to decode as UTF8 - ret_bytes = PyUnicode_AsUTF8String(scalar); - if (ret_bytes == NULL) { + ret = PyUnicode_AsUTF8String(scalar); + if (ret == NULL) { + PyErr_SetString( + PyExc_TypeError, + "Can only store UTF8 text in a StringDType array."); + return NULL; + } + } + else if (scalar_type == StringNA_Type) { + ret = scalar; + Py_INCREF(ret); + } + // store np.nan as NA + else if (scalar_type == &PyFloat_Type) { + double scalar_val = PyFloat_AsDouble(scalar); + if ((scalar_val == -1.0) && PyErr_Occurred()) { + return NULL; + } + if (npy_isnan(scalar_val)) { + ret = NA_OBJ; + Py_INCREF(ret); + } + else { PyErr_SetString( PyExc_TypeError, "Can only store UTF8 text in a StringDType array."); @@ -99,7 +122,7 @@ get_value(PyObject *scalar) "Can only store String text in a StringDType array."); return NULL; } - return ret_bytes; + return ret; } // Take a python object `obj` and insert it into the array of dtype `descr` at @@ -109,58 +132,75 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj, char **dataptr) { PyObject *val_obj = get_value(obj); + if (val_obj == NULL) { return -1; } - char *val = NULL; - Py_ssize_t length = 0; - if (PyBytes_AsStringAndSize(val_obj, &val, &length) == -1) { - return -1; - } + ss *sdata = (ss *)dataptr; // free if dataptr holds preexisting string data, // ssfree does a NULL check - ssfree((ss *)dataptr); + ssfree(sdata); - // copies contents of val into item_val->buf - int res = ssnewlen(val, length, (ss *)dataptr); + // RichCompareBool short-circuits to a pointer comparison fast-path + // so no need to do pointer comparison first + int eq_res = PyObject_RichCompareBool(val_obj, NA_OBJ, Py_EQ); - // val_obj must stay alive until here to ensure *val* doesn't get - // deallocated - Py_DECREF(val_obj); + if 
(eq_res < 0) { + goto error; + } - if (res == -1) { - PyErr_NoMemory(); - return -1; + if (eq_res == 1) { + // do nothing, ssfree already NULLed the struct sdata points to + // so it already contains a NA value } - else if (res == -2) { - // this should never happen - assert(0); + else { + char *val = NULL; + Py_ssize_t length = 0; + if (PyBytes_AsStringAndSize(val_obj, &val, &length) == -1) { + goto error; + } + + // copies contents of val into item_val->buf + int res = ssnewlen(val, length, sdata); + + if (res == -1) { + PyErr_NoMemory(); + goto error; + } + else if (res == -2) { + // this should never happen + assert(0); + goto error; + } } + Py_DECREF(val_obj); return 0; + +error: + Py_DECREF(val_obj); + return -1; } static PyObject * stringdtype_getitem(StringDTypeObject *NPY_UNUSED(descr), char **dataptr) { - char *data; - size_t len; + PyObject *val_obj = NULL; + ss *sdata = (ss *)dataptr; - if (*dataptr == NULL) { - data = "\0"; - len = 0; + if (ss_isnull(sdata)) { + Py_INCREF(NA_OBJ); + val_obj = NA_OBJ; } else { - data = ((ss *)dataptr)->buf; - len = ((ss *)dataptr)->len; - } - - PyObject *val_obj = PyUnicode_FromStringAndSize(data, len); - - if (val_obj == NULL) { - return NULL; + char *data = sdata->buf; + size_t len = sdata->len; + val_obj = PyUnicode_FromStringAndSize(data, len); + if (val_obj == NULL) { + return NULL; + } } /* @@ -190,10 +230,16 @@ nonzero(void *data, void *NPY_UNUSED(arr)) int compare(void *a, void *b, void *NPY_UNUSED(arr)) { - ss *ss_a = NULL; - ss *ss_b = NULL; - load_string(a, &ss_a); - load_string(b, &ss_b); + ss *ss_a = (ss *)a; + ss *ss_b = (ss *)b; + int a_is_null = ss_isnull(ss_a); + int b_is_null = ss_isnull(ss_b); + if (a_is_null) { + return b_is_null ? 0 : 1; // NA equals NA; a lone NA sorts last + } + else if (b_is_null) { + return -1; + } return strcmp(ss_a->buf, ss_b->buf); } @@ -265,6 +311,35 @@ stringdtype_get_clear_loop(void *NPY_UNUSED(traverse_context), return 0; } +static int +stringdtype_fill_zero_loop(void *NPY_UNUSED(traverse_context), + PyArray_Descr 
*NPY_UNUSED(descr), char *data, + npy_intp size, npy_intp stride, + NpyAuxData *NPY_UNUSED(auxdata)) +{ + while (size--) { + if (ssnewlen("", 0, (ss *)(data)) < 0) { + return -1; + } + data += stride; + } + return 0; +} + +static int +stringdtype_get_fill_zero_loop(void *NPY_UNUSED(traverse_context), + PyArray_Descr *NPY_UNUSED(descr), + int NPY_UNUSED(aligned), + npy_intp NPY_UNUSED(fixed_stride), + traverse_loop_function **out_loop, + NpyAuxData **NPY_UNUSED(out_auxdata), + NPY_ARRAYMETHOD_FLAGS *flags) +{ + *flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + *out_loop = &stringdtype_fill_zero_loop; + return 0; +} + static PyType_Slot StringDType_Slots[] = { {NPY_DT_common_instance, &common_instance}, {NPY_DT_common_dtype, &common_dtype}, @@ -278,6 +353,7 @@ static PyType_Slot StringDType_Slots[] = { {NPY_DT_PyArray_ArrFuncs_argmax, &argmax}, {NPY_DT_PyArray_ArrFuncs_argmin, &argmin}, {NPY_DT_get_clear_loop, &stringdtype_get_clear_loop}, + {NPY_DT_get_fill_zero_loop, &stringdtype_get_fill_zero_loop}, {0, NULL}}; static PyObject * @@ -441,3 +517,16 @@ init_string_dtype(void) return 0; } + +int +init_string_na_object(PyObject *mod) +{ + // propagate the AttributeError instead of dereferencing NULL below + NA_OBJ = PyObject_GetAttrString(mod, "NA"); + if (NA_OBJ == NULL) { + return -1; + } + StringNA_Type = Py_TYPE(NA_OBJ); + Py_INCREF(StringNA_Type); + return 0; +} diff --git a/stringdtype/stringdtype/src/dtype.h b/stringdtype/stringdtype/src/dtype.h index 056c62e9..13071bba 100644 --- a/stringdtype/stringdtype/src/dtype.h +++ b/stringdtype/stringdtype/src/dtype.h @@ -12,6 +12,7 @@ #include "numpy/arrayobject.h" #include "numpy/experimental_dtype_api.h" #include "numpy/ndarraytypes.h" +#include "numpy/npy_math.h" typedef struct { PyArray_Descr base; @@ -29,6 +30,10 @@ init_string_dtype(void); int compare(void *, void *, void *); +int +init_string_na_object(PyObject *mod); + + // from dtypemeta.h, not public in numpy #define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr)) diff --git a/stringdtype/stringdtype/src/main.c b/stringdtype/stringdtype/src/main.c index 
00f4d0e0..d53453a4 100644 --- a/stringdtype/stringdtype/src/main.c +++ b/stringdtype/stringdtype/src/main.c @@ -91,7 +91,7 @@ PyInit__main(void) if (_import_array() < 0) { return NULL; } - if (import_experimental_dtype_api(9) < 0) { + if (import_experimental_dtype_api(10) < 0) { return NULL; } @@ -117,6 +117,10 @@ PyInit__main(void) goto error; } + if (init_string_na_object(mod) < 0) { + goto error; + } + Py_INCREF((PyObject *)&StringDType); if (PyModule_AddObject(m, "StringDType", (PyObject *)&StringDType) < 0) { Py_DECREF((PyObject *)&StringDType); diff --git a/stringdtype/stringdtype/src/static_string.c b/stringdtype/stringdtype/src/static_string.c index 143d6eaa..f423e6a0 100644 --- a/stringdtype/stringdtype/src/static_string.c +++ b/stringdtype/stringdtype/src/static_string.c @@ -1,12 +1,21 @@ #include "static_string.h" +static ss EMPTY = {0, ""}; + int ssnewlen(const char *init, size_t len, ss *to_init) { - if ((to_init->buf != NULL) || (to_init->len != 0)) { + if ((to_init == NULL) || (to_init->buf != NULL) || (to_init->len != 0)) { return -2; } + if (len == 0) { + // share the static empty buffer (ssfree never frees it) and skip malloc + to_init->len = 0; + to_init->buf = EMPTY.buf; + return 0; + } + // one extra byte for null terminator char *ret_buf = (char *)malloc(sizeof(char) * (len + 1)); @@ -31,7 +40,9 @@ void ssfree(ss *str) { if (str->buf != NULL) { - free(str->buf); + if (str->buf != EMPTY.buf) { + free(str->buf); + } str->buf = NULL; } str->len = 0; @@ -40,7 +51,14 @@ ssfree(ss *str) int ssdup(ss *in, ss *out) { - return ssnewlen(in->buf, in->len, out); + if (ss_isnull(in)) { + out->len = 0; + out->buf = NULL; + return 0; + } + else { + return ssnewlen(in->buf, in->len, out); + } } int @@ -62,16 +80,11 @@ ssnewemptylen(size_t num_bytes, ss *out) return 0; } -static ss EMPTY = {0, "\0"}; - -void -load_string(char *data, ss **out) +int +ss_isnull(ss *in) { - ss *ss_d = (ss *)data; - if (ss_d->len == 0) { - *out = &EMPTY; - } - else { - *out = ss_d; + if (in->len == 0 && in->buf == NULL) { + return 1; } + return 0; } diff --git 
a/stringdtype/stringdtype/src/static_string.h b/stringdtype/stringdtype/src/static_string.h index c4a956d2..cab0a368 100644 --- a/stringdtype/stringdtype/src/static_string.h +++ b/stringdtype/stringdtype/src/static_string.h @@ -34,13 +34,9 @@ ssdup(ss *in, ss *out); int ssnewemptylen(size_t num_bytes, ss *out); -// Interpret the contents of buffer *data* as an ss struct and set *out* to -// that struct. If *data* is NULL, set *out* to point to a statically -// allocated, empty SS struct. Since this function may set *out* to point to -// statically allocated data, do not ever free memory owned by an output of -// this function. That means this function is most useful for read-only -// applications. -void -load_string(char *data, ss **out); +// Determine if *in* corresponds to a NULL ss struct (i.e. len is zero and buf +// is NULL). Returns 1 if this is the case and zero otherwise. Cannot fail. +int +ss_isnull(ss *in); #endif /*_NPY_STATIC_STRING_H */ diff --git a/stringdtype/stringdtype/src/umath.c b/stringdtype/stringdtype/src/umath.c index 30330962..3caea02e 100644 --- a/stringdtype/stringdtype/src/umath.c +++ b/stringdtype/stringdtype/src/umath.c @@ -142,10 +142,14 @@ string_equal_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context), ss *s1 = NULL, *s2 = NULL; while (N--) { - load_string(in1, &s1); - load_string(in2, &s2); - - if (s1->len == s2->len && strncmp(s1->buf, s2->buf, s1->len) == 0) { + s1 = (ss *)in1; + s2 = (ss *)in2; + if (ss_isnull(s1) || ss_isnull(s2)) { + // s1 or s2 is NA + *out = (npy_bool)0; + } + else if (s1->len == s2->len && + strncmp(s1->buf, s2->buf, s1->len) == 0) { *out = (npy_bool)1; } else { @@ -183,14 +187,23 @@ string_isnan_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context), NpyAuxData *NPY_UNUSED(auxdata)) { npy_intp N = dimensions[0]; + char *in = data[0]; npy_bool *out = (npy_bool *)data[1]; + npy_intp in_stride = strides[0]; npy_intp out_stride = strides[1]; + ss *s = NULL; + while (N--) { - // we could represent missing 
data with a null pointer, but - // should isnan return True in that case? - *out = (npy_bool)0; + s = (ss *)in; + if (ss_isnull(s)) { + *out = (npy_bool)1; + } + else { + *out = (npy_bool)0; + } + in += in_stride; out += out_stride; } diff --git a/stringdtype/tests/test_stringdtype.py b/stringdtype/tests/test_stringdtype.py index 136320bc..f6283331 100644 --- a/stringdtype/tests/test_stringdtype.py +++ b/stringdtype/tests/test_stringdtype.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from stringdtype import StringDType, StringScalar, _memory_usage +from stringdtype import NA, StringDType, StringScalar, _memory_usage @pytest.fixture @@ -121,9 +121,9 @@ def test_equality_promotion(string_list): def test_isnan(string_list): - sarr = np.array(string_list, dtype=StringDType()) + sarr = np.array(string_list + [NA], dtype=StringDType()) np.testing.assert_array_equal( - np.isnan(sarr), np.zeros_like(sarr, dtype=np.bool_) + np.isnan(sarr), np.array([0] * len(string_list) + [1], dtype=np.bool_) ) @@ -185,12 +185,24 @@ def test_pickle(string_list): ) def test_sort(strings): """Test that sorting matches python's internal sorting.""" - arr = np.array(strings, dtype=StringDType()) + + def test_sort(strings, arr_sorted): + arr = np.array(strings, dtype=StringDType()) + np.random.default_rng().shuffle(arr) + arr.sort() + assert np.array_equal(arr, arr_sorted, equal_nan=True) + arr_sorted = np.array(sorted(strings), dtype=StringDType()) + test_sort(strings, arr_sorted) + + # make sure NAs get sorted to the end of the array + strings.insert(0, NA) + strings.insert(2, NA) + # can't use append because doing that with NA converts + # the result to object dtype + arr_sorted = np.array(arr_sorted.tolist() + [NA, NA], dtype=StringDType()) - np.random.default_rng().shuffle(arr) - arr.sort() - np.testing.assert_array_equal(arr, arr_sorted) + test_sort(strings, arr_sorted) @pytest.mark.parametrize( @@ -212,12 +224,11 @@ def test_creation_functions(): np.zeros(3, dtype=StringDType()), 
["", "", ""] ) - np.testing.assert_array_equal( - np.empty(3, dtype=StringDType()), ["", "", ""] - ) + assert np.zeros(3, dtype=StringDType())[0] == "" + + assert np.all(np.isnan(np.empty(3, dtype=StringDType()))) - # make sure getitem works too - assert np.empty(3, dtype=StringDType())[0] == "" + assert np.empty(3, dtype=StringDType())[0] is NA def test_is_numeric(): @@ -244,14 +255,14 @@ def test_argmax(strings): @pytest.mark.parametrize( "arrfunc,expected", [ - [np.sort, np.empty(10, dtype=StringDType())], + [np.sort, np.zeros(10, dtype=StringDType())], [np.nonzero, (np.array([], dtype=np.int64),)], [np.argmax, 0], [np.argmin, 0], ], ) -def test_arrfuncs_empty(arrfunc, expected): - arr = np.empty(10, dtype=StringDType()) +def test_arrfuncs_zeros(arrfunc, expected): + arr = np.zeros(10, dtype=StringDType()) result = arrfunc(arr) np.testing.assert_array_equal(result, expected, strict=True) @@ -316,3 +327,14 @@ def test_ufunc_add(string_list, other_strings): np.add(arr1, arr2), np.array([a + b for a, b in zip(arr1, arr2)], dtype=StringDType()), ) + + +@pytest.mark.parametrize("na_val", [float("nan"), np.nan, NA]) +def test_create_with_na(na_val): + string_list = ["hello", na_val, "world"] + arr = np.array(string_list, dtype=StringDType()) + assert ( + repr(arr) + == "array(['hello', stringdtype.NA, 'world'], dtype=StringDType())" + ) + assert arr[1] == NA and arr[1] is NA diff --git a/unytdtype/meson.build b/unytdtype/meson.build index e14e7a9f..2857b288 100644 --- a/unytdtype/meson.build +++ b/unytdtype/meson.build @@ -35,7 +35,8 @@ py.install_sources( 'unytdtype/__init__.py', 'unytdtype/scalar.py' ], - subdir: 'unytdtype' + subdir: 'unytdtype', + pure: false ) py.extension_module( diff --git a/unytdtype/unytdtype/src/unytdtype_main.c b/unytdtype/unytdtype/src/unytdtype_main.c index 26e116ba..541d1e2c 100644 --- a/unytdtype/unytdtype/src/unytdtype_main.c +++ b/unytdtype/unytdtype/src/unytdtype_main.c @@ -21,7 +21,7 @@ PyInit__unytdtype_main(void) if 
(_import_array() < 0) { return NULL; } - if (import_experimental_dtype_api(9) < 0) { + if (import_experimental_dtype_api(10) < 0) { return NULL; }