Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pandas/core/groupby/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ def _gotitem(self, key, ndim, subset=None):
# require postprocessing of the result by transform.
cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])

cython_cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])

# List of aggregation/reduction functions.
# These map each group to a single numeric value
reduction_kernels = frozenset(
Expand Down
41 changes: 29 additions & 12 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
is_timedelta64_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import ABCCategoricalIndex
from pandas.core.dtypes.missing import (
isna,
Expand Down Expand Up @@ -96,6 +95,10 @@ class WrappedCythonOp:
Dispatch logic for functions defined in _libs.groupby
"""

# Functions for which we do _not_ attempt to cast the cython result
# back to the original dtype.
cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])

def __init__(self, kind: str, how: str):
self.kind = kind
self.how = how
Expand Down Expand Up @@ -568,11 +571,13 @@ def _ea_wrap_cython_operation(
if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
# All of the functions implemented here are ordinal, so we can
# operate on the tz-naive equivalents
values = values.view("M8[ns]")
npvalues = values.view("M8[ns]")
res_values = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
kind, npvalues, how, axis, min_count, **kwargs
)
if how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
# preserve float64 dtype
return res_values

Expand All @@ -586,21 +591,33 @@ def _ea_wrap_cython_operation(
res_values = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
if isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)
if how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
return res_values

return res_values
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
# has no attribute "construct_array_type"
cls = dtype.construct_array_type() # type: ignore[union-attr]
return cls._from_sequence(res_values, dtype=dtype)

elif is_float_dtype(values.dtype):
# FloatingArray
values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
res_values = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
result = type(orig_values)._from_sequence(res_values)
return result
if how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
return res_values

dtype = maybe_cast_result_dtype(orig_values.dtype, how)
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
# has no attribute "construct_array_type"
cls = dtype.construct_array_type() # type: ignore[union-attr]
return cls._from_sequence(res_values, dtype=dtype)

raise NotImplementedError(
f"function is not implemented for this dtype: {values.dtype}"
Expand Down Expand Up @@ -713,9 +730,9 @@ def _cython_operation(

result = result.T

if how not in base.cython_cast_blocklist:
if how not in cy_op.cast_blocklist:
# e.g. if we are int64 and need to restore to datetime64/timedelta64
# "rank" is the only member of cython_cast_blocklist we get here
# "rank" is the only member of cast_blocklist we get here
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
op_result = maybe_downcast_to_dtype(result, dtype)
else:
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/groupby/test_counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def test_ngroup_respects_groupby_order(self):
[
[Timestamp(f"2016-05-{i:02d} 20:09:25+00:00") for i in range(1, 4)],
[Timestamp(f"2016-05-{i:02d} 20:09:25") for i in range(1, 4)],
[Timestamp(f"2016-05-{i:02d} 20:09:25", tz="UTC") for i in range(1, 4)],
[Timedelta(x, unit="h") for x in range(1, 4)],
[Period(freq="2W", year=2017, month=x) for x in range(1, 4)],
],
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,17 @@ def test_idxmin_idxmax_returns_int_types(func, values):
df["c_date_tz"] = df["c_date"].dt.tz_localize("US/Pacific")
df["c_timedelta"] = df["c_date"] - df["c_date"].iloc[0]
df["c_period"] = df["c_date"].dt.to_period("W")
df["c_Integer"] = df["c_int"].astype("Int64")
df["c_Floating"] = df["c_float"].astype("Float64")

result = getattr(df.groupby("name"), func)()

expected = DataFrame(values, index=Index(["A", "B"], name="name"))
expected["c_date_tz"] = expected["c_date"]
expected["c_timedelta"] = expected["c_date"]
expected["c_period"] = expected["c_date"]
expected["c_Integer"] = expected["c_int"]
expected["c_Floating"] = expected["c_float"]

tm.assert_frame_equal(result, expected)

Expand Down
2 changes: 2 additions & 0 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,6 +1732,8 @@ def test_pivot_table_values_key_error():
[to_datetime(0)],
[date_range(0, 1, 1, tz="US/Eastern")],
[pd.array([0], dtype="Int64")],
[pd.array([0], dtype="Float64")],
[pd.array([False], dtype="boolean")],
],
)
@pytest.mark.parametrize("method", ["attr", "agg", "apply"])
Expand Down
13 changes: 12 additions & 1 deletion pandas/tests/groupby/test_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,19 @@ def test_rank_resets_each_group(pct, exp):
tm.assert_frame_equal(result, exp_df)


def test_rank_avg_even_vals():
@pytest.mark.parametrize(
"dtype", ["int64", "int32", "uint64", "uint32", "float64", "float32"]
)
@pytest.mark.parametrize("upper", [True, False])
def test_rank_avg_even_vals(dtype, upper):
if upper:
# use IntegerDtype/FloatingDtype
dtype = dtype[0].upper() + dtype[1:]
dtype = dtype.replace("Ui", "UI")
df = DataFrame({"key": ["a"] * 4, "val": [1] * 4})
df["val"] = df["val"].astype(dtype)
assert df["val"].dtype == dtype

result = df.groupby("key").rank()
exp_df = DataFrame([2.5, 2.5, 2.5, 2.5], columns=["val"])
tm.assert_frame_equal(result, exp_df)
Expand Down