Skip to content

Commit c72e3ee

Browse files
TomAugspurger authored and jorisvandenbossche committed
Use NA scalar in string dtype (#1)
1 parent 1849a23 commit c72e3ee

File tree

9 files changed

+59
-30
lines changed

9 files changed

+59
-30
lines changed

pandas/_libs/lib.pyx

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1500,7 +1500,7 @@ cdef class Validator:
15001500
f'must define is_value_typed')
15011501

15021502
cdef bint is_valid_null(self, object value) except -1:
1503-
return value is None or util.is_nan(value)
1503+
return value is None or value is C_NA or util.is_nan(value)
15041504

15051505
cdef bint is_array_typed(self) except -1:
15061506
return False

pandas/_libs/testing.pyx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,15 @@ cpdef assert_almost_equal(a, b,
180180
# classes can't be the same, to raise error
181181
assert_class_equal(a, b, obj=obj)
182182

183-
if a == b:
184-
# object comparison
185-
return True
186183
if isna(a) and isna(b):
187184
# TODO: Should require same-dtype NA?
188185
# nan / None comparison
189186
return True
187+
188+
if a == b:
189+
# object comparison
190+
return True
191+
190192
if is_comparable_as_number(a) and is_comparable_as_number(b):
191193
if array_equivalent(a, b, strict_nan=True):
192194
# inf comparison

pandas/core/arrays/numpy_.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ def fillna(self, value=None, method=None, limit=None):
278278
return new_values
279279

280280
def take(self, indices, allow_fill=False, fill_value=None):
281+
if fill_value is None:
282+
# Primarily for subclasses
283+
fill_value = self.dtype.na_value
281284
result = take(
282285
self._ndarray, indices, allow_fill=allow_fill, fill_value=fill_value
283286
)

pandas/core/arrays/string_.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import operator
2-
from typing import TYPE_CHECKING, Type
2+
from typing import Type
33

44
import numpy as np
55

6-
from pandas._libs import lib
6+
from pandas._libs import lib, missing as libmissing
77

88
from pandas.core.dtypes.base import ExtensionDtype
99
from pandas.core.dtypes.common import pandas_dtype
@@ -17,9 +17,6 @@
1717
from pandas.core.construction import extract_array
1818
from pandas.core.missing import isna
1919

20-
if TYPE_CHECKING:
21-
from pandas._typing import Scalar
22-
2320

2421
@register_extension_dtype
2522
class StringDtype(ExtensionDtype):
@@ -50,16 +47,8 @@ class StringDtype(ExtensionDtype):
5047
StringDtype
5148
"""
5249

53-
@property
54-
def na_value(self) -> "Scalar":
55-
"""
56-
StringDtype uses :attr:`numpy.nan` as the missing NA value.
57-
58-
.. warning::
59-
60-
`na_value` may change in a future release.
61-
"""
62-
return np.nan
50+
#: StringDtype.na_value uses pandas.NA
51+
na_value = libmissing.NA
6352

6453
@property
6554
def type(self) -> Type:
@@ -172,10 +161,10 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
172161
if dtype:
173162
assert dtype == "string"
174163
result = super()._from_sequence(scalars, dtype=object, copy=copy)
175-
# convert None to np.nan
164+
# Standardize all missing-like values to NA
176165
# TODO: it would be nice to do this in _validate / lib.is_string_array
177166
# We are already doing a scan over the values there.
178-
result[result.isna()] = np.nan
167+
result[result.isna()] = StringDtype.na_value
179168
return result
180169

181170
@classmethod
@@ -192,6 +181,12 @@ def __arrow_array__(self, type=None):
192181
type = pa.string()
193182
return pa.array(self._ndarray, type=type, from_pandas=True)
194183

184+
def _values_for_factorize(self):
185+
arr = self._ndarray.copy()
186+
mask = self.isna()
187+
arr[mask] = -1
188+
return arr, -1
189+
195190
def __setitem__(self, key, value):
196191
value = extract_array(value, extract_numpy=True)
197192
if isinstance(value, type(self)):
@@ -205,9 +200,9 @@ def __setitem__(self, key, value):
205200

206201
# validate new items
207202
if scalar_value:
208-
if scalar_value is None:
209-
value = np.nan
210-
elif not (isinstance(value, str) or np.isnan(value)):
203+
if isna(value):
204+
value = StringDtype.na_value
205+
elif not isinstance(value, str):
211206
raise ValueError(
212207
"Cannot set non-string value '{}' into a StringArray.".format(value)
213208
)
@@ -265,7 +260,7 @@ def method(self, other):
265260
other = other[valid]
266261

267262
result = np.empty_like(self._ndarray, dtype="object")
268-
result[mask] = np.nan
263+
result[mask] = StringDtype.na_value
269264
result[valid] = op(self._ndarray[valid], other)
270265

271266
if op.__name__ in {"add", "radd", "mul", "rmul"}:

pandas/core/dtypes/missing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,9 @@ def array_equivalent(left, right, strict_nan=False):
446446
if left_value is NaT and right_value is not NaT:
447447
return False
448448

449+
elif left_value is libmissing.NA and right_value is not libmissing.NA:
450+
return False
451+
449452
elif isinstance(left_value, float) and np.isnan(left_value):
450453
if not isinstance(right_value, float) or not np.isnan(right_value):
451454
return False
@@ -457,6 +460,8 @@ def array_equivalent(left, right, strict_nan=False):
457460
if "Cannot compare tz-naive" in str(err):
458461
# tzawareness compat failure, see GH#28507
459462
return False
463+
elif "boolean value of NA is ambiguous" in str(err):
464+
return False
460465
raise
461466
return True
462467

pandas/core/strings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ def _map(f, arr, na_mask=False, na_value=np.nan, dtype=object):
123123
if na_mask:
124124
mask = isna(arr)
125125
convert = not np.all(mask)
126+
if convert:
127+
# XXX: This converts pd.NA to np.nan to match the output of
128+
# object-dtype ops that return numeric, like str.count
129+
# We probably want to return Int64Dtype instead.
130+
# NA -> nan
131+
arr[mask] = np.nan
132+
126133
try:
127134
result = lib.map_infer_mask(arr, f, mask.view(np.uint8), convert)
128135
except (TypeError, AttributeError) as e:

pandas/tests/arrays/string_/test_string.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
def test_none_to_nan():
1313
a = pd.arrays.StringArray._from_sequence(["a", None, "b"])
1414
assert a[1] is not None
15-
assert np.isnan(a[1])
15+
assert a[1] is pd.NA
1616

1717

1818
def test_setitem_validates():
@@ -24,6 +24,15 @@ def test_setitem_validates():
2424
a[:] = np.array([1, 2])
2525

2626

27+
def test_setitem_with_scalar_string():
28+
# is_float_dtype considers some strings, like 'd', to be floats
29+
# which can cause issues.
30+
arr = pd.array(["a", "c"], dtype="string")
31+
arr[0] = "d"
32+
expected = pd.array(["d", "c"], dtype="string")
33+
tm.assert_extension_array_equal(arr, expected)
34+
35+
2736
@pytest.mark.parametrize(
2837
"input, method",
2938
[

pandas/tests/extension/test_string.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def data():
2525
@pytest.fixture
2626
def data_missing():
2727
"""Length 2 array with [NA, Valid]"""
28-
return StringArray._from_sequence([np.nan, "A"])
28+
return StringArray._from_sequence([pd.NA, "A"])
2929

3030

3131
@pytest.fixture
@@ -35,17 +35,17 @@ def data_for_sorting():
3535

3636
@pytest.fixture
3737
def data_missing_for_sorting():
38-
return StringArray._from_sequence(["B", np.nan, "A"])
38+
return StringArray._from_sequence(["B", pd.NA, "A"])
3939

4040

4141
@pytest.fixture
4242
def na_value():
43-
return np.nan
43+
return pd.NA
4444

4545

4646
@pytest.fixture
4747
def data_for_grouping():
48-
return StringArray._from_sequence(["B", "B", np.nan, np.nan, "A", "A", "B", "C"])
48+
return StringArray._from_sequence(["B", "B", pd.NA, pd.NA, "A", "A", "B", "C"])
4949

5050

5151
class TestDtype(base.BaseDtypeTests):

pandas/tests/test_strings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,3 +3507,11 @@ def test_string_array(any_string_method):
35073507
assert all(result[columns].dtypes == "string")
35083508
result[columns] = result[columns].astype(object)
35093509
tm.assert_equal(result, expected)
3510+
3511+
3512+
@pytest.mark.xfail(reason="not implmented yet")
3513+
def test_string_dtype_numeric():
3514+
s = Series(["a", "aa", None], dtype="string")
3515+
result = s.str.count("a")
3516+
expected = Series([1, 2, None], dtype="Int64")
3517+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)