Skip to content

Commit 3f841c7

Browse files
committed
collect into one function
1 parent b96bf59 commit 3f841c7

File tree

6 files changed

+128
-93
lines changed

6 files changed

+128
-93
lines changed

pandas/core/arrays/base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -555,17 +555,17 @@ def searchsorted(self, value, side="left", sorter=None):
555555
.. versionadded:: 0.24.0
556556
557557
Find the indices into a sorted array `self` (a) such that, if the
558-
corresponding elements in `v` were inserted before the indices, the
559-
order of `self` would be preserved.
558+
corresponding elements in `value` were inserted before the indices,
559+
the order of `self` would be preserved.
560560
561-
Assuming that `a` is sorted:
561+
Assuming that `self` is sorted:
562562
563-
====== ============================
563+
====== ================================
564564
`side` returned index `i` satisfies
565-
====== ============================
566-
left ``self[i-1] < v <= self[i]``
567-
right ``self[i-1] <= v < self[i]``
568-
====== ============================
565+
====== ================================
566+
left ``self[i-1] < value <= self[i]``
567+
right ``self[i-1] <= value < self[i]``
568+
====== ================================
569569
570570
Parameters
571571
----------
@@ -581,7 +581,7 @@ def searchsorted(self, value, side="left", sorter=None):
581581
582582
Returns
583583
-------
584-
indices : array of ints
584+
array of ints
585585
Array of insertion points with the same shape as `value`.
586586
587587
See Also

pandas/core/arrays/numpy_.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
from pandas._libs import lib
66
from pandas.compat.numpy import function as nv
7+
from pandas.util._decorators import Appender
78
from pandas.util._validators import validate_fillna_kwargs
89

910
from pandas.core.dtypes.dtypes import ExtensionDtype
1011
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
1112
from pandas.core.dtypes.inference import is_array_like, is_list_like
1213

1314
from pandas import compat
14-
from pandas.core import nanops
15+
from pandas.core import common as com, nanops
1516
from pandas.core.missing import backfill_1d, pad_1d
1617

1718
from .base import ExtensionArray, ExtensionOpsMixin
@@ -423,6 +424,11 @@ def to_numpy(self, dtype=None, copy=False):
423424

424425
return result
425426

427+
@Appender(ExtensionArray.searchsorted.__doc__)
428+
def searchsorted(self, value, side='left', sorter=None):
429+
return com.searchsorted(self.to_numpy(), value,
430+
side=side, sorter=sorter)
431+
426432
# ------------------------------------------------------------------------
427433
# Ops
428434

pandas/core/base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,15 +1514,11 @@ def factorize(self, sort=False, na_sentinel=-1):
15141514
array([3])
15151515
""")
15161516

1517-
@Substitution(klass='IndexOpsMixin')
1517+
@Substitution(klass='Index')
15181518
@Appender(_shared_docs['searchsorted'])
15191519
def searchsorted(self, value, side='left', sorter=None):
1520-
result = com.searchsorted(self._values, value,
1521-
side=side, sorter=sorter)
1522-
1523-
if is_scalar(value):
1524-
return result if is_scalar(result) else result[0]
1525-
return result
1520+
return com.searchsorted(self._values, value,
1521+
side=side, sorter=sorter)
15261522

15271523
def drop_duplicates(self, keep='first', inplace=False):
15281524
inplace = validate_bool_kwarg(inplace, 'inplace')

pandas/core/common.py

Lines changed: 64 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,19 @@
1313
import numpy as np
1414

1515
from pandas._libs import lib, tslibs
16+
from pandas.compat import PY36, OrderedDict, iteritems
17+
1618
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
17-
from pandas import compat
18-
from pandas.compat import iteritems, PY36, OrderedDict
19-
from pandas.core.dtypes.generic import ABCSeries, ABCIndex, ABCIndexClass
2019
from pandas.core.dtypes.common import (
21-
is_integer, is_integer_dtype, is_bool_dtype,
22-
is_extension_array_dtype, is_array_like, is_object_dtype,
23-
is_categorical_dtype, is_numeric_dtype, is_scalar, ensure_platform_int)
20+
ensure_platform_int, is_array_like, is_bool_dtype, is_categorical_dtype,
21+
is_extension_array_dtype, is_integer, is_integer_dtype, is_numeric_dtype,
22+
is_object_dtype, is_scalar)
23+
from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries
2424
from pandas.core.dtypes.inference import _iterable_not_string
2525
from pandas.core.dtypes.missing import isna, isnull, notnull # noqa
2626

27+
from pandas import compat
28+
2729

2830
class SettingWithCopyError(ValueError):
2931
pass
@@ -485,87 +487,79 @@ def f(x):
485487
return f
486488

487489

488-
def searchsorted_integer(arr, value, side="left", sorter=None):
489-
"""
490-
searchsorted implementation for searching integer arrays.
491-
492-
We get a speedup if we ensure the dtype of arr and value are the same
493-
(if possible) before searchingm as numpy implicitly converts the dtypes
494-
if they're different, which would cause a slowdown.
495-
496-
See :func:`searchsorted` for a more general searchsorted implementation.
497-
498-
Parameters
499-
----------
500-
arr : numpy.array
501-
a numpy array of integers
502-
value : int or numpy.array
503-
an integer or an array of integers that we want to find the
504-
location(s) for in `arr`
505-
side : str
506-
One of {'left', 'right'}
507-
sorter : numpy.array, optional
508-
509-
Returns
510-
-------
511-
int or numpy.array
512-
The locations(s) of `value` in `arr`.
513-
"""
514-
from .arrays.array_ import array
515-
if sorter is not None:
516-
sorter = ensure_platform_int(sorter)
517-
518-
# below we try to give `value` the same dtype as `arr`, while guarding
519-
# against integer overflows. If the value of `value` is outside of the
520-
# bound of `arr`, `arr` would be recast by numpy, causing a slower search.
521-
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
522-
iinfo = np.iinfo(arr.dtype.type)
523-
if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all():
524-
dtype = arr.dtype
525-
else:
526-
dtype = value_arr.dtype
527-
528-
if is_scalar(value):
529-
value = dtype.type(value)
530-
else:
531-
value = array(value, dtype=dtype)
532-
533-
return arr.searchsorted(value, side=side, sorter=sorter)
534-
535-
536490
def searchsorted(arr, value, side="left", sorter=None):
537491
"""
538492
Find indices where elements should be inserted to maintain order.
539493
540-
Find the indices into a sorted array-like `arr` such that, if the
494+
.. versionadded:: 0.25.0
495+
496+
Find the indices into a sorted array `self` (a) such that, if the
541497
corresponding elements in `value` were inserted before the indices,
542-
the order of `arr` would be preserved.
498+
the order of `self` would be preserved.
499+
500+
Assuming that `self` is sorted:
543501
544-
See :class:`IndexOpsMixin.searchsorted` for more details and examples.
502+
====== ================================
503+
`side` returned index `i` satisfies
504+
====== ================================
505+
left ``self[i-1] < value <= self[i]``
506+
right ``self[i-1] <= value < self[i]``
507+
====== ================================
545508
546509
Parameters
547510
----------
548-
arr : numpy.array or ExtensionArray
549-
value : scalar or numpy.array
550-
side : str
551-
One of {'left', 'right'}
552-
sorter : numpy.array, optional
511+
arr: numpy.array or ExtensionArray
512+
array to search in. Cannot be Index, Series or PandasArray, as that
513+
would cause a RecursionError.
514+
value : array_like
515+
Values to insert into `arr`.
516+
side : {'left', 'right'}, optional
517+
If 'left', the index of the first suitable location found is given.
518+
If 'right', return the last such index. If there is no suitable
519+
index, return either 0 or N (where N is the length of `self`).
520+
sorter : 1-D array_like, optional
521+
Optional array of integer indices that sort array a into ascending
522+
order. They are typically the result of argsort.
553523
554524
Returns
555525
-------
556-
int or numpy.array
557-
The locations(s) of `value` in `arr`.
526+
array of ints
527+
Array of insertion points with the same shape as `value`.
528+
529+
See Also
530+
--------
531+
numpy.searchsorted : Similar method from NumPy.
558532
"""
559533
if sorter is not None:
560534
sorter = ensure_platform_int(sorter)
561535

562536
if is_integer_dtype(arr) and (
563537
is_integer(value) or is_integer_dtype(value)):
564-
return searchsorted_integer(arr, value, side=side, sorter=sorter)
565-
if not (is_object_dtype(arr) or is_numeric_dtype(arr) or
566-
is_categorical_dtype(arr)):
538+
from .arrays.array_ import array
539+
# if `arr` and `value` have different dtypes, `arr` would be
540+
# recast by numpy, causing a slow search.
541+
# Before searching below, we therefore try to give `value` the
542+
# same dtype as `arr`, while guarding against integer overflows.
543+
iinfo = np.iinfo(arr.dtype.type)
544+
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
545+
if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all():
546+
# value within bounds, so no overflow, so can convert value dtype
547+
# to dtype of arr
548+
dtype = arr.dtype
549+
else:
550+
dtype = value_arr.dtype
551+
552+
if is_scalar(value):
553+
value = dtype.type(value)
554+
else:
555+
value = array(value, dtype=dtype)
556+
elif not (is_object_dtype(arr) or is_numeric_dtype(arr) or
557+
is_categorical_dtype(arr)):
558+
from pandas.core.series import Series
567559
# E.g. if `arr` is an array with dtype='datetime64[ns]'
568560
# and `value` is a pd.Timestamp, we may need to convert value
569-
from pandas.core.series import Series
570-
value = Series(value)._values
571-
return arr.searchsorted(value, side=side, sorter=sorter)
561+
value_ser = Series(value)._values
562+
value = value_ser[0] if is_scalar(value) else value_ser
563+
564+
result = arr.searchsorted(value, side=side, sorter=sorter)
565+
return result

pandas/core/series.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2388,12 +2388,8 @@ def __rmatmul__(self, other):
23882388
@Substitution(klass='Series')
23892389
@Appender(base._shared_docs['searchsorted'])
23902390
def searchsorted(self, value, side='left', sorter=None):
2391-
result = com.searchsorted(self._values, value,
2392-
side=side, sorter=sorter)
2393-
2394-
if is_scalar(value):
2395-
return result if is_scalar(result) else result[0]
2396-
return result
2391+
return com.searchsorted(self._values, value,
2392+
side=side, sorter=sorter)
23972393

23982394
# -------------------------------------------------------------------
23992395
# Combination

pandas/tests/arrays/test_array.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import pandas as pd
1111
from pandas.api.extensions import register_extension_dtype
12+
from pandas.api.types import is_scalar
1213
from pandas.core.arrays import PandasArray, integer_array, period_array
1314
from pandas.tests.extension.decimal import (
1415
DecimalArray, DecimalDtype, to_decimal)
@@ -254,3 +255,45 @@ def test_array_not_registered(registry_without_decimal):
254255
result = pd.array(data, dtype=DecimalDtype)
255256
expected = DecimalArray._from_sequence(data)
256257
tm.assert_equal(result, expected)
258+
259+
260+
class TestArrayAnalytics(object):
261+
def test_searchsorted(self, string_dtype):
262+
arr = pd.array(['a', 'b', 'c'], dtype=string_dtype)
263+
264+
result = arr.searchsorted('a', side='left')
265+
assert is_scalar(result)
266+
assert result == 0
267+
268+
result = arr.searchsorted('a', side='right')
269+
assert is_scalar(result)
270+
assert result == 1
271+
272+
def test_searchsorted_numeric_dtypes_scalar(self, any_real_dtype):
273+
arr = pd.array([1, 3, 90], dtype=any_real_dtype)
274+
result = arr.searchsorted(30)
275+
assert is_scalar(result)
276+
assert result == 2
277+
278+
result = arr.searchsorted([30])
279+
expected = np.array([2], dtype=np.intp)
280+
tm.assert_numpy_array_equal(result, expected)
281+
282+
def test_searchsorted_numeric_dtypes_vector(self, any_real_dtype):
283+
arr = pd.array([1, 3, 90], dtype=any_real_dtype)
284+
result = arr.searchsorted([2, 30])
285+
expected = np.array([1, 2], dtype=np.intp)
286+
tm.assert_numpy_array_equal(result, expected)
287+
288+
def test_search_sorted_datetime64_scalar(self):
289+
arr = pd.array(pd.date_range('20120101', periods=10, freq='2D'))
290+
val = pd.Timestamp('20120102')
291+
result = arr.searchsorted(val)
292+
assert is_scalar(result)
293+
assert result == 1
294+
295+
def test_searchsorted_sorter(self, any_real_dtype):
296+
arr = pd.array([3, 1, 2], dtype=any_real_dtype)
297+
result = arr.searchsorted([0, 3], sorter=np.argsort(arr))
298+
expected = np.array([0, 2], dtype=np.intp)
299+
tm.assert_numpy_array_equal(result, expected)

0 commit comments

Comments
 (0)