1111from pandas .core .dtypes .common import (
1212 ensure_int64 , ensure_platform_int , is_categorical_dtype , is_list_like )
1313from pandas .core .dtypes .missing import isna
14+ from pandas .core .dtypes .generic import ABCExtensionArray
1415
1516import pandas .core .algorithms as algorithms
1617
@@ -404,7 +405,8 @@ def _reorder_by_uniques(uniques, labels):
404405 return uniques , labels
405406
406407
407- def safe_sort (values , labels = None , na_sentinel = - 1 , assume_unique = False ):
408+ def safe_sort (values , labels = None , na_sentinel = - 1 , assume_unique = False ,
409+ check_outofbounds = True ):
408410 """
409411 Sort ``values`` and reorder corresponding ``labels``.
410412 ``values`` should be unique if ``labels`` is not None.
@@ -425,6 +427,10 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
425427 assume_unique : bool, default False
426428 When True, ``values`` are assumed to be unique, which can speed up
427429 the calculation. Ignored when ``labels`` is None.
430+ check_outofbounds : bool, default True
431+ Check if labels are out of bound for the values and put out of bound
432+ labels equal to na_sentinel. If ``check_outofbounds=False``, it is
433+ assumed there are no out of bound labels.
428434
429435 Returns
430436 -------
@@ -446,8 +452,8 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
446452 raise TypeError ("Only list-like objects are allowed to be passed to"
447453 "safe_sort as values" )
448454
449- if not isinstance (values , np .ndarray ):
450-
455+ if ( not isinstance (values , np .ndarray )
456+ and not isinstance ( values , ABCExtensionArray )):
451457 # don't convert to string types
452458 dtype , _ = infer_dtype_from_array (values )
453459 values = np .asarray (values , dtype = dtype )
@@ -461,7 +467,8 @@ def sort_mixed(values):
461467 return np .concatenate ([nums , np .asarray (strs , dtype = object )])
462468
463469 sorter = None
464- if PY3 and lib .infer_dtype (values , skipna = False ) == 'mixed-integer' :
470+ if (PY3 and not isinstance (values , ABCExtensionArray )
471+ and lib .infer_dtype (values , skipna = False ) == 'mixed-integer' ):
465472 # unorderable in py3 if mixed str/int
466473 ordered = sort_mixed (values )
467474 else :
@@ -494,15 +501,26 @@ def sort_mixed(values):
494501 t .map_locations (values )
495502 sorter = ensure_platform_int (t .lookup (ordered ))
496503
497- reverse_indexer = np .empty (len (sorter ), dtype = np .int_ )
498- reverse_indexer .put (sorter , np .arange (len (sorter )))
499-
500- mask = (labels < - len (values )) | (labels >= len (values )) | \
501- (labels == na_sentinel )
502-
503- # (Out of bound indices will be masked with `na_sentinel` next, so we may
504- # deal with them here without performance loss using `mode='wrap'`.)
505- new_labels = reverse_indexer .take (labels , mode = 'wrap' )
506- np .putmask (new_labels , mask , na_sentinel )
504+ if na_sentinel == - 1 :
505+ # take_1d is faster, but only works for na_sentinels of -1
506+ order2 = sorter .argsort ()
507+ new_labels = algorithms .take_1d (order2 , labels , fill_value = - 1 )
508+ if check_outofbounds :
509+ mask = (labels < - len (values )) | (labels >= len (values ))
510+ else :
511+ mask = None
512+ else :
513+ reverse_indexer = np .empty (len (sorter ), dtype = np .int_ )
514+ reverse_indexer .put (sorter , np .arange (len (sorter )))
515+ # Out of bound indices will be masked with `na_sentinel` next, so we
516+ # may deal with them here without performance loss using `mode='wrap'`
517+ new_labels = reverse_indexer .take (labels , mode = 'wrap' )
518+
519+ mask = labels == na_sentinel
520+ if check_outofbounds :
521+ mask = mask | (labels < - len (values )) | (labels >= len (values ))
522+
523+ if mask is not None :
524+ np .putmask (new_labels , mask , na_sentinel )
507525
508526 return ordered , ensure_platform_int (new_labels )
0 commit comments