Skip to content

Commit 7edf3e0

Browse files
committed
Respond to review comments
1 parent 1f5c191 commit 7edf3e0

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

stringdtype/stringdtype/src/dtype.c

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,7 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj,
152152
}
153153

154154
if (eq_res == 1) {
155+
// NULL is the representation of NA in the array buffer
155156
sdata = NULL;
156157
}
157158
else {
@@ -233,17 +234,11 @@ compare(void *a, void *b, void *NPY_UNUSED(arr))
233234
ss *ss_b = (ss *)b;
234235
int a_is_null = ss_isnull(ss_a);
235236
int b_is_null = ss_isnull(ss_b);
236-
if (a_is_null || b_is_null) {
237-
// numpy sorts NaNs to the end of the array
238-
// pandas sorts NAs to the end as well
239-
// so we follow that behavior here
240-
if (!b_is_null) {
241-
return 1;
242-
}
243-
else if (!a_is_null) {
244-
return -1;
245-
}
246-
return 0;
237+
if (a_is_null) {
238+
return 1;
239+
}
240+
else if (b_is_null) {
241+
return -1;
247242
}
248243
return strcmp(ss_a->buf, ss_b->buf);
249244
}

stringdtype/tests/test_stringdtype.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -185,12 +185,24 @@ def test_pickle(string_list):
185185
)
186186
def test_sort(strings):
187187
"""Test that sorting matches python's internal sorting."""
188-
arr = np.array(strings, dtype=StringDType())
188+
189+
def test_sort(strings, arr_sorted):
190+
arr = np.array(strings, dtype=StringDType())
191+
np.random.default_rng().shuffle(arr)
192+
arr.sort()
193+
assert np.array_equal(arr, arr_sorted, equal_nan=True)
194+
189195
arr_sorted = np.array(sorted(strings), dtype=StringDType())
196+
test_sort(strings, arr_sorted)
197+
198+
# make sure NAs get sorted to the end of the array
199+
strings.insert(0, NA)
200+
strings.insert(2, NA)
201+
# can't use append because doing that with NA converts
202+
# the result to object dtype
203+
arr_sorted = np.array(arr_sorted.tolist() + [NA, NA], dtype=StringDType())
190204

191-
np.random.default_rng().shuffle(arr)
192-
arr.sort()
193-
np.testing.assert_array_equal(arr, arr_sorted)
205+
test_sort(strings, arr_sorted)
194206

195207

196208
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)