@@ -944,27 +944,29 @@ def select_n_frame(frame, columns, n, method, keep):
944
944
-------
945
945
nordered : DataFrame
946
946
"""
947
+ from pandas .core .series import Series
947
948
if not is_list_like (columns ):
948
949
columns = [columns ]
949
950
else :
950
951
columns = list (columns )
951
- reverse = method == 'nlargest'
952
+
953
+ ascending = method == 'nsmallest'
954
+ tmp = Series (frame .index )
955
+ frame .reset_index (inplace = True , drop = True )
956
+ cur_frame = frame
957
+
952
958
for i , column in enumerate (columns ):
953
- series = frame [column ]
954
- if reverse :
955
- inds = series .argsort ()[::- 1 ][:n ]
956
- else :
957
- inds = series .argsort ()[:n ]
958
- values = series .iloc [inds ]
959
- if i != len (columns ) - 1 and values .duplicated ().any ():
960
- # This series has duplicate values => we must consider all rows in
961
- # frame that match `values`
962
- # The first condition is for the last column. In this case we don't
963
- # care if there are duplicates => no need to do the check
964
- frame = frame [series .isin (values )]
959
+ series = cur_frame [column ]
960
+ values = getattr (series , method )(n , keep = keep )
961
+ if i + 1 == len (columns ) or values .is_unique :
962
+ cur_frame = cur_frame .reindex (values .index )
965
963
else :
966
- break
967
- return frame .take (inds )
964
+ cur_frame = cur_frame [series .isin (values )].sort_values (
965
+ column , ascending = ascending
966
+ )
967
+ frame .index = tmp .ix [frame .index ]
968
+ cur_frame .index = tmp .ix [cur_frame .index ]
969
+ return cur_frame
968
970
969
971
970
972
def _finalize_nsmallest (arr , kth_val , n , keep , narr ):
0 commit comments