@@ -776,7 +776,15 @@ def test_nsmallest():
776
776
777
777
@pytest .mark .parametrize (
778
778
"data, groups" ,
779
- [([0 , 1 , 2 , 3 ], [0 , 0 , 1 , 1 ]), ([0 ], [0 ])],
779
+ [
780
+ ([0 , 1 , 2 , 3 ], [0 , 0 , 1 , 1 ]),
781
+ ([0 ], [0 ]),
782
+ * [
783
+ (np .array (data , dtype = dtyp ), np .array (groups , dtype = dtyp ))
784
+ for data , groups in [([0 , 1 , 2 , 3 ], [0 , 0 , 1 , 1 ]), ([0 ], [0 ])]
785
+ for dtyp in tm .ALL_INT_NUMPY_DTYPES
786
+ ],
787
+ ],
780
788
)
781
789
@pytest .mark .parametrize ("method" , ["nlargest" , "nsmallest" ])
782
790
def test_nlargest_and_smallest_noop (data , groups , method ):
@@ -786,8 +794,9 @@ def test_nlargest_and_smallest_noop(data, groups, method):
786
794
if method == "nlargest" :
787
795
data = list (reversed (data ))
788
796
ser = Series (data , name = "a" )
789
- result = getattr (ser .groupby (np .array (groups , dtype = np .int64 )), method )(n = 2 )
790
- expected = Series (data , index = MultiIndex .from_arrays ([groups , ser .index ]), name = "a" )
797
+ result = getattr (ser .groupby (groups ), method )(n = 2 )
798
+ expidx = np .array (groups , dtype = np .intp ) if isinstance (groups , list ) else groups
799
+ expected = Series (data , index = MultiIndex .from_arrays ([expidx , ser .index ]), name = "a" )
791
800
tm .assert_series_equal (result , expected )
792
801
793
802
0 commit comments