diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 023b360c5bb5c..78d73f3f93c4a 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -944,27 +944,29 @@ def select_n_frame(frame, columns, n, method, keep): ------- nordered : DataFrame """ + from pandas.core.series import Series if not is_list_like(columns): columns = [columns] else: columns = list(columns) - reverse = method == 'nlargest' + + ascending = method == 'nsmallest' + tmp = Series(frame.index) + frame.reset_index(inplace=True, drop=True) + cur_frame = frame + for i, column in enumerate(columns): - series = frame[column] - if reverse: - inds = series.argsort()[::-1][:n] - else: - inds = series.argsort()[:n] - values = series.iloc[inds] - if i != len(columns) - 1 and values.duplicated().any(): - # This series has duplicate values => we must consider all rows in - # frame that match `values` - # The first condition is for the last column. In this case we don't - # care if there are duplicates => no need to do the check - frame = frame[series.isin(values)] + series = cur_frame[column] + values = getattr(series, method)(n, keep=keep) + if i + 1 == len(columns) or values.is_unique: + cur_frame = cur_frame.reindex(values.index) else: - break - return frame.take(inds) + cur_frame = cur_frame[series.isin(values)].sort_values( + column, ascending=ascending + ) + frame.index = tmp.ix[frame.index] + cur_frame.index = tmp.ix[cur_frame.index] + return cur_frame def _finalize_nsmallest(arr, kth_val, n, keep, narr):