Skip to content

Commit a82f905

Browse files
authored
PERF: ArrowExtensionArray.searchsorted (#50447)
* PERF: ArrowExtensionArray.searchsorted * gh ref * fix * add test for searchsorted when array contains pd.NA * remove print statement * move tests
1 parent 1d73a56 commit a82f905

File tree

5 files changed

+81
-1
lines changed

5 files changed

+81
-1
lines changed

doc/source/whatsnew/v2.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,7 @@ Performance improvements
757757
- Performance improvement in :meth:`MultiIndex.putmask` (:issue:`49830`)
758758
- Performance improvement in :meth:`Index.union` and :meth:`MultiIndex.union` when index contains duplicates (:issue:`48900`)
759759
- Performance improvement in :meth:`Series.rank` for pyarrow-backed dtypes (:issue:`50264`)
760+
- Performance improvement in :meth:`Series.searchsorted` for pyarrow-backed dtypes (:issue:`50447`)
760761
- Performance improvement in :meth:`Series.fillna` for extension array dtypes (:issue:`49722`, :issue:`50078`)
761762
- Performance improvement in :meth:`Index.join`, :meth:`Index.intersection` and :meth:`Index.union` for masked dtypes when :class:`Index` is monotonic (:issue:`50310`)
762763
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)

pandas/core/arrays/arrow/array.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import (
44
TYPE_CHECKING,
55
Any,
6+
Literal,
67
TypeVar,
78
cast,
89
)
@@ -116,6 +117,11 @@ def floordiv_compat(
116117
}
117118

118119
if TYPE_CHECKING:
120+
from pandas._typing import (
121+
NumpySorter,
122+
NumpyValueArrayLike,
123+
)
124+
119125
from pandas import Series
120126

121127
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
@@ -693,6 +699,23 @@ def round(
693699
"""
694700
return type(self)(pc.round(self._data, ndigits=decimals))
695701

702+
@doc(ExtensionArray.searchsorted)
703+
def searchsorted(
704+
self,
705+
value: NumpyValueArrayLike | ExtensionArray,
706+
side: Literal["left", "right"] = "left",
707+
sorter: NumpySorter = None,
708+
) -> npt.NDArray[np.intp] | np.intp:
709+
if self._hasna:
710+
raise ValueError(
711+
"searchsorted requires array to be sorted, which is impossible "
712+
"with NAs present."
713+
)
714+
if isinstance(value, ExtensionArray):
715+
value = value.astype(object)
716+
# Base class searchsorted would cast to object, which is *much* slower.
717+
return self.to_numpy().searchsorted(value, side=side, sorter=sorter)
718+
696719
def take(
697720
self,
698721
indices: TakeIndexer,

pandas/core/arrays/string_.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import (
4+
TYPE_CHECKING,
5+
Literal,
6+
)
47

58
import numpy as np
69

@@ -54,6 +57,11 @@
5457
if TYPE_CHECKING:
5558
import pyarrow
5659

60+
from pandas._typing import (
61+
NumpySorter,
62+
NumpyValueArrayLike,
63+
)
64+
5765
from pandas import Series
5866

5967

@@ -492,6 +500,20 @@ def memory_usage(self, deep: bool = False) -> int:
492500
return result + lib.memory_usage_of_objects(self._ndarray)
493501
return result
494502

503+
@doc(ExtensionArray.searchsorted)
504+
def searchsorted(
505+
self,
506+
value: NumpyValueArrayLike | ExtensionArray,
507+
side: Literal["left", "right"] = "left",
508+
sorter: NumpySorter = None,
509+
) -> npt.NDArray[np.intp] | np.intp:
510+
if self._hasna:
511+
raise ValueError(
512+
"searchsorted requires array to be sorted, which is impossible "
513+
"with NAs present."
514+
)
515+
return super().searchsorted(value=value, side=side, sorter=sorter)
516+
495517
def _cmp_method(self, other, op):
496518
from pandas.arrays import BooleanArray
497519

pandas/tests/extension/test_arrow.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,3 +1553,20 @@ def test_round():
15531553
result = ser.round(-1)
15541554
expected = pd.Series([120.0, pd.NA, 60.0], dtype=dtype)
15551555
tm.assert_series_equal(result, expected)
1556+
1557+
1558+
def test_searchsorted_with_na_raises(data_for_sorting, as_series):
1559+
# GH50447
1560+
b, c, a = data_for_sorting
1561+
arr = data_for_sorting.take([2, 0, 1]) # to get [a, b, c]
1562+
arr[-1] = pd.NA
1563+
1564+
if as_series:
1565+
arr = pd.Series(arr)
1566+
1567+
msg = (
1568+
"searchsorted requires array to be sorted, "
1569+
"which is impossible with NAs present."
1570+
)
1571+
with pytest.raises(ValueError, match=msg):
1572+
arr.searchsorted(b)

pandas/tests/extension/test_string.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,20 @@ def arrow_not_supported(self, data, request):
420420
reason="2D support not implemented for ArrowStringArray"
421421
)
422422
request.node.add_marker(mark)
423+
424+
425+
def test_searchsorted_with_na_raises(data_for_sorting, as_series):
426+
# GH50447
427+
b, c, a = data_for_sorting
428+
arr = data_for_sorting.take([2, 0, 1]) # to get [a, b, c]
429+
arr[-1] = pd.NA
430+
431+
if as_series:
432+
arr = pd.Series(arr)
433+
434+
msg = (
435+
"searchsorted requires array to be sorted, "
436+
"which is impossible with NAs present."
437+
)
438+
with pytest.raises(ValueError, match=msg):
439+
arr.searchsorted(b)

0 commit comments

Comments
 (0)