Skip to content

Commit df8ac2f

Browse files
authored
ENH: implement EA._where (#44187)
1 parent 9e398a6 commit df8ac2f

File tree

6 files changed

+39
-32
lines changed

6 files changed

+39
-32
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def putmask(self, mask: np.ndarray, value) -> None:
320320

321321
np.putmask(self._ndarray, mask, value)
322322

323-
def where(
323+
def _where(
324324
self: NDArrayBackedExtensionArrayT, mask: np.ndarray, value
325325
) -> NDArrayBackedExtensionArrayT:
326326
"""

pandas/core/arrays/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,6 +1411,31 @@ def insert(self: ExtensionArrayT, loc: int, item) -> ExtensionArrayT:
14111411

14121412
return type(self)._concat_same_type([self[:loc], item_arr, self[loc:]])
14131413

1414+
def _where(
1415+
self: ExtensionArrayT, mask: npt.NDArray[np.bool_], value
1416+
) -> ExtensionArrayT:
1417+
"""
1418+
Analogue to np.where(mask, self, value)
1419+
1420+
Parameters
1421+
----------
1422+
mask : np.ndarray[bool]
1423+
value : scalar or listlike
1424+
1425+
Returns
1426+
-------
1427+
same type as self
1428+
"""
1429+
result = self.copy()
1430+
1431+
if is_list_like(value):
1432+
val = value[~mask]
1433+
else:
1434+
val = value
1435+
1436+
result[~mask] = val
1437+
return result
1438+
14141439
@classmethod
14151440
def _empty(cls, shape: Shape, dtype: ExtensionDtype):
14161441
"""

pandas/core/arrays/sparse/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,13 @@ def to_dense(self) -> np.ndarray:
13071307

13081308
_internal_get_values = to_dense
13091309

1310+
def _where(self, mask, value):
1311+
# NB: may not preserve dtype, e.g. result may be Sparse[float64]
1312+
# while self is Sparse[int64]
1313+
naive_implementation = np.where(mask, self, value)
1314+
result = type(self)._from_sequence(naive_implementation)
1315+
return result
1316+
13101317
# ------------------------------------------------------------------------
13111318
# IO
13121319
# ------------------------------------------------------------------------

pandas/core/internals/blocks.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
is_extension_array_dtype,
5151
is_interval_dtype,
5252
is_list_like,
53-
is_sparse,
5453
is_string_dtype,
5554
)
5655
from pandas.core.dtypes.dtypes import (
@@ -1596,30 +1595,9 @@ def where(self, other, cond, errors="raise") -> list[Block]:
15961595
# for the type
15971596
other = self.dtype.na_value
15981597

1599-
if is_sparse(self.values):
1600-
# TODO(SparseArray.__setitem__): remove this if condition
1601-
# We need to re-infer the type of the data after doing the
1602-
# where, for cases where the subtypes don't match
1603-
dtype = None
1604-
else:
1605-
dtype = self.dtype
1606-
1607-
result = self.values.copy()
1608-
icond = ~cond
1609-
if lib.is_scalar(other):
1610-
set_other = other
1611-
else:
1612-
set_other = other[icond]
16131598
try:
1614-
result[icond] = set_other
1615-
except (NotImplementedError, TypeError):
1616-
# NotImplementedError for class not implementing `__setitem__`
1617-
# TypeError for SparseArray, which implements just to raise
1618-
# a TypeError
1619-
if isinstance(result, Categorical):
1620-
# TODO: don't special-case
1621-
raise
1622-
1599+
result = self.values._where(cond, other)
1600+
except TypeError:
16231601
if is_interval_dtype(self.dtype):
16241602
# TestSetitemFloatIntervalWithIntIntervalValues
16251603
blk = self.coerce_to_target_dtype(other)
@@ -1628,10 +1606,7 @@ def where(self, other, cond, errors="raise") -> list[Block]:
16281606
# Interval[int64]->Interval[float64]
16291607
raise
16301608
return blk.where(other, cond, errors)
1631-
1632-
result = type(self.values)._from_sequence(
1633-
np.where(cond, self.values, other), dtype=dtype
1634-
)
1609+
raise
16351610

16361611
return [self.make_block_same_class(result)]
16371612

@@ -1721,7 +1696,7 @@ def where(self, other, cond, errors="raise") -> list[Block]:
17211696
cond = extract_bool_array(cond)
17221697

17231698
try:
1724-
res_values = arr.T.where(cond, other).T
1699+
res_values = arr.T._where(cond, other).T
17251700
except (ValueError, TypeError):
17261701
return Block.where(self, other, cond, errors=errors)
17271702

pandas/tests/indexes/categorical/test_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def test_where_non_categories(self):
325325
msg = "Cannot setitem on a Categorical with a new category"
326326
with pytest.raises(TypeError, match=msg):
327327
# Test the Categorical method directly
328-
ci._data.where(mask, 2)
328+
ci._data._where(mask, 2)
329329

330330

331331
class TestContains:

pandas/tests/series/indexing/test_where.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def test_where_datetimelike_categorical(tz_naive_fixture):
459459
tm.assert_index_equal(res, dr)
460460

461461
# DatetimeArray.where
462-
res = lvals._data.where(mask, rvals)
462+
res = lvals._data._where(mask, rvals)
463463
tm.assert_datetime_array_equal(res, dr._data)
464464

465465
# Series.where

0 commit comments

Comments
 (0)