Skip to content

Commit bd6f248

Browse files
committed
Simplify implementation
1 parent da44399 commit bd6f248

File tree

1 file changed

+54
-55
lines changed

1 file changed

+54
-55
lines changed

pandas/core/common.py

Lines changed: 54 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515
from pandas._libs import lib, tslibs
1616
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
1717
from pandas import compat
18-
from pandas.compat import iteritems, PY2, PY36, OrderedDict
18+
from pandas.compat import iteritems, PY36, OrderedDict
1919
from pandas.core.dtypes.generic import ABCSeries, ABCIndex, ABCIndexClass
20-
from pandas.core.dtypes.common import (is_integer, is_integer_dtype,
21-
is_bool_dtype, is_extension_array_dtype,
22-
is_array_like,
23-
is_float_dtype, is_object_dtype,
24-
is_categorical_dtype, is_numeric_dtype,
25-
is_scalar, ensure_platform_int)
20+
from pandas.core.dtypes.common import (
21+
is_integer, is_integer_dtype, is_bool_dtype,
22+
is_extension_array_dtype, is_array_like, is_object_dtype,
23+
is_categorical_dtype, is_numeric_dtype, is_scalar, ensure_platform_int)
2624
from pandas.core.dtypes.inference import _iterable_not_string
2725
from pandas.core.dtypes.missing import isna, isnull, notnull # noqa
2826

@@ -487,58 +485,47 @@ def f(x):
487485
return f
488486

489487

490-
def ensure_integer_dtype(arr, value):
488+
def searchsorted_integer(arr, value, side="left", sorter=None):
491489
"""
492-
Ensure optimal dtype for :func:`searchsorted_integer` is returned.
490+
searchsorted implementation for searching integer arrays.
491+
492+
We get a speedup if we ensure the dtype of arr and value are the same
493+
(if possible) before searching, as numpy implicitly converts the dtypes
494+
if they're different, which would cause a slowdown.
495+
496+
See :func:`searchsorted` for a more general searchsorted implementation.
493497
494498
Parameters
495499
----------
496-
arr : a numpy integer array
497-
value : a number or array of numbers
500+
arr : numpy.array
501+
a numpy array of integers
502+
value : int or numpy.array
503+
an integer or an array of integers that we want to find the
504+
location(s) for in `arr`
505+
side : str
506+
One of {'left', 'right'}
507+
sorter : numpy.array, optional
498508
499509
Returns
500510
-------
501-
dtype : an numpy integer dtype
502-
503-
Raises
504-
------
505-
TypeError : if value is not a number
506-
"""
507-
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
508-
509-
if PY2 and not is_numeric_dtype(value_arr):
510-
# python 2 allows "a" < 1, avoid such nonsense
511-
msg = "value must be numeric, was type {}"
512-
raise TypeError(msg.format(value))
513-
514-
iinfo = np.iinfo(arr.dtype)
515-
if not ((value_arr < iinfo.min).any() or (value_arr > iinfo.max).any()):
516-
return arr.dtype
517-
else:
518-
return value_arr.dtype
519-
520-
521-
def searchsorted_integer(arr, value, side="left", sorter=None):
522-
"""
523-
searchsorted implementation, but only for integer arrays.
524-
525-
We get a speedup if the dtype of arr and value is the same.
526-
527-
See :func:`searchsorted` for a more general searchsorted implementation.
511+
int or numpy.array
512+
The location(s) of `value` in `arr`.
528513
"""
529514
if sorter is not None:
530515
sorter = ensure_platform_int(sorter)
531516

532-
dtype = ensure_integer_dtype(arr, value)
533-
534-
if is_integer(value) or is_integer_dtype(value):
535-
value = np.asarray(value, dtype=dtype)
536-
elif hasattr(value, 'is_integer') and value.is_integer():
537-
# float 2.0 can be converted to int 2 for better speed,
538-
# but float 2.2 should *not* be converted to int 2
539-
value = np.asarray(value, dtype=dtype)
517+
# below we try to give `value` the same dtype as `arr`, while guarding
518+
# against integer overflows. If the value of `value` is outside of the
519+
# bounds of `arr`'s dtype, `arr` would be recast by numpy, causing a slower search.
520+
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
521+
iinfo = np.iinfo(arr.dtype)
522+
if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all():
523+
dtype = arr.dtype
524+
else:
525+
dtype = value_arr.dtype
526+
value = np.asarray(value, dtype=dtype)
540527

541-
return np.searchsorted(arr, value, side=side, sorter=sorter)
528+
return arr.searchsorted(value, side=side, sorter=sorter)
542529

543530

544531
def searchsorted(arr, value, side="left", sorter=None):
@@ -550,18 +537,30 @@ def searchsorted(arr, value, side="left", sorter=None):
550537
the order of `arr` would be preserved.
551538
552539
See :class:`IndexOpsMixin.searchsorted` for more details and examples.
540+
541+
Parameters
542+
----------
543+
arr : numpy.array or ExtensionArray
544+
value : scalar or numpy.array
545+
side : str
546+
One of {'left', 'right'}
547+
sorter : numpy.array, optional
548+
549+
Returns
550+
-------
551+
int or numpy.array
552+
The location(s) of `value` in `arr`.
553553
"""
554554
if sorter is not None:
555555
sorter = ensure_platform_int(sorter)
556556

557-
if is_integer_dtype(arr):
557+
if is_integer_dtype(arr) and (
558+
is_integer(value) or is_integer_dtype(value)):
558559
return searchsorted_integer(arr, value, side=side, sorter=sorter)
559-
elif (is_object_dtype(arr) or is_float_dtype(arr) or
560-
is_categorical_dtype(arr)):
561-
return arr.searchsorted(value, side=side, sorter=sorter)
562-
else:
563-
# fallback solution. E.g. arr is an array with dtype='datetime64[ns]'
564-
# and value is a pd.Timestamp, need to convert value
560+
if not (is_object_dtype(arr) or is_numeric_dtype(arr) or
561+
is_categorical_dtype(arr)):
562+
# E.g. if `arr` is an array with dtype='datetime64[ns]'
563+
# and `value` is a pd.Timestamp, we may need to convert value
565564
from pandas.core.series import Series
566565
value = Series(value)._values
567-
return arr.searchsorted(value, side=side, sorter=sorter)
566+
return arr.searchsorted(value, side=side, sorter=sorter)

0 commit comments

Comments
 (0)