Skip to content

Commit 9b1dce2

Browse files
committed
float16 & groupby.nlargest/nsmallest tests
1 parent 9d8822d commit 9b1dce2

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

pandas/tests/groupby/test_function.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,15 @@ def test_nsmallest():
776776

777777
@pytest.mark.parametrize(
778778
"data, groups",
779-
[([0, 1, 2, 3], [0, 0, 1, 1]), ([0], [0])],
779+
[
780+
([0, 1, 2, 3], [0, 0, 1, 1]),
781+
([0], [0]),
782+
*[
783+
(np.array(data, dtype=dtyp), np.array(groups, dtype=dtyp))
784+
for data, groups in [([0, 1, 2, 3], [0, 0, 1, 1]), ([0], [0])]
785+
for dtyp in tm.ALL_INT_NUMPY_DTYPES
786+
],
787+
],
780788
)
781789
@pytest.mark.parametrize("method", ["nlargest", "nsmallest"])
782790
def test_nlargest_and_smallest_noop(data, groups, method):
@@ -786,8 +794,9 @@ def test_nlargest_and_smallest_noop(data, groups, method):
786794
if method == "nlargest":
787795
data = list(reversed(data))
788796
ser = Series(data, name="a")
789-
result = getattr(ser.groupby(np.array(groups, dtype=np.int64)), method)(n=2)
790-
expected = Series(data, index=MultiIndex.from_arrays([groups, ser.index]), name="a")
797+
result = getattr(ser.groupby(groups), method)(n=2)
798+
expidx = np.array(groups, dtype=np.intp) if isinstance(groups, list) else groups
799+
expected = Series(data, index=MultiIndex.from_arrays([expidx, ser.index]), name="a")
791800
tm.assert_series_equal(result, expected)
792801

793802

pandas/tests/indexes/test_numpy_compat.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,7 @@ def test_numpy_ufuncs_basic(index, func):
8686
if is_complex_dtype(index):
8787
assert result.dtype == "complex64"
8888
elif index.dtype in ["bool", "int8", "uint8"]:
89-
# TODO: find out when index.dtype is float16 vs float32 (platform issue)
9089
assert result.dtype in ["float16", "float32"]
91-
elif index.dtype in ["float16"]:
92-
assert result.dtype == "float16"
9390
elif index.dtype in ["int16", "uint16", "float32"]:
9491
assert result.dtype == "float32"
9592
else:

0 commit comments

Comments
 (0)