From 83f13655777000fcbeb2462de00c3b16ab8ff450 Mon Sep 17 00:00:00 2001 From: Ethan Chen Date: Sun, 13 Sep 2020 20:13:59 -0400 Subject: [PATCH 1/5] BUG: DataFrameGroupBy.transform with axis=1 fails (#36308) --- pandas/core/groupby/generic.py | 8 +++++++- pandas/core/groupby/groupby.py | 4 ++-- pandas/tests/frame/apply/test_frame_transform.py | 2 -- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index ffd756bed43b6..f03c2824a2a26 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1676,7 +1676,13 @@ def _wrap_transformed_output( columns.name = self.obj.columns.name result = self.obj._constructor(indexed_output) - result.columns = columns + + if self.axis == 1: + result = result.T + result.columns = self.obj.columns + else: + result.columns = columns + result.index = self.obj.index return result diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 30bd53a3ddff1..e6a7567240141 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2668,8 +2668,8 @@ def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None, axis=0 fill_method = "pad" limit = 0 filled = getattr(self, fill_method)(limit=limit) - fill_grp = filled.groupby(self.grouper.codes) - shifted = fill_grp.shift(periods=periods, freq=freq) + fill_grp = filled.groupby(self.grouper.codes, axis=self.axis) + shifted = fill_grp.shift(periods=periods, freq=freq, axis=self.axis) return (filled / shifted) - 1 @Substitution(name="groupby") diff --git a/pandas/tests/frame/apply/test_frame_transform.py b/pandas/tests/frame/apply/test_frame_transform.py index 346e60954fc13..c0b85561464c1 100644 --- a/pandas/tests/frame/apply/test_frame_transform.py +++ b/pandas/tests/frame/apply/test_frame_transform.py @@ -27,8 +27,6 @@ def test_transform_groupby_kernel(axis, float_frame, op): pytest.xfail("DataFrame.cumcount does not exist") if op == "tshift": pytest.xfail("Only works on time index and is deprecated") - if axis == 1 or axis == "columns": - pytest.xfail("GH 36308: groupby.transform with axis=1 is broken") args = [0.0] if op == "fillna" else [] if axis == 0 or axis == "index": From 643d4cf8d42eba74be1577cb11d92877df018b92 Mon Sep 17 00:00:00 2001 From: rhshadrach Date: Fri, 13 Nov 2020 21:55:12 -0500 Subject: [PATCH 2/5] Minor cleanup; fixed cumcount; added tests --- pandas/core/groupby/generic.py | 5 ++--- pandas/core/groupby/groupby.py | 2 +- .../tests/groupby/transform/test_transform.py | 20 ++++++++++++++++++- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index e4dc7cea3e08f..3395b9d36fd0c 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1675,15 +1675,14 @@ def _wrap_transformed_output( DataFrame """ indexed_output = {key.position: val for key, val in output.items()} - columns = Index(key.label for key in output) - columns.name = self.obj.columns.name - result = self.obj._constructor(indexed_output) if self.axis == 1: result = result.T result.columns = self.obj.columns else: + columns = Index(key.label for key in output) + columns.name = self.obj.columns.name result.columns = columns result.index = self.obj.index diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index a64c79781aac5..ec96a0d502d3f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2365,7 +2365,7 @@ def cumcount(self, ascending: bool = True): dtype: int64 """ with group_selection_context(self): - index = self._selected_obj.index + index = self._selected_obj._get_axis(self.axis) cumcounts = self._cumcount_array(ascending=ascending) return self._obj_1d_constructor(cumcounts, index) diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index d7426a5e3b42e..eb00b8e7d7a0d 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -158,7 +158,25 @@ def test_transform_broadcast(tsframe, ts): assert_fp_equal(res.xs(idx), agged[idx]) -def test_transform_axis(tsframe): +def test_transform_axis_1(transformation_func): + # GH 36308 + if transformation_func == "tshift": + pytest.xfail("tshift is deprecated") + args = ("ffill",) if transformation_func == "fillna" else tuple() + + df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) + result = df.groupby([0, 0, 1], axis=1).transform(transformation_func, *args) + expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T + + if transformation_func == "diff": + # Result contains nans, so transpose coerces to float + expected["b"] = expected["b"].astype(int) + + # cumcount returns Series; the rest are DataFrame + tm.assert_equal(result, expected) + + +def test_transform_axis_ts(tsframe): # make sure that we are setting the axes # correctly when on axis=0 or 1 From 64e7d9e08c69367016c6705128a7871b1038d781 Mon Sep 17 00:00:00 2001 From: rhshadrach Date: Fri, 13 Nov 2020 22:24:07 -0500 Subject: [PATCH 3/5] whatsnew --- doc/source/whatsnew/v1.2.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 09cb024cbd95c..1aa88e56689c2 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -552,7 +552,7 @@ Groupby/resample/rolling - Bug in :meth:`Rolling.median` and :meth:`Rolling.quantile` returned wrong values for :class:`BaseIndexer` subclasses with non-monotonic starting or ending points for windows (:issue:`37153`) - Bug in :meth:`DataFrame.groupby` dropped ``nan`` groups from result with ``dropna=False`` when grouping over a single column (:issue:`35646`, :issue:`35542`) - Bug in :meth:`DataFrameGroupBy.head`, :meth:`DataFrameGroupBy.tail`, :meth:`SeriesGroupBy.head`, and :meth:`SeriesGroupBy.tail` would raise when used with ``axis=1`` (:issue:`9772`) - +- Bug in :meth:`DataFrameGroupBy.transform` would raise when used with ``axis=1`` and a transformation kernel (e.g. "shift") (:issue:`36308`) Reshaping ^^^^^^^^^ From 47b4865c1ae58b8c0b48bccc04c49df3f03943d3 Mon Sep 17 00:00:00 2001 From: rhshadrach Date: Fri, 13 Nov 2020 23:23:18 -0500 Subject: [PATCH 4/5] int32 fix --- pandas/tests/groupby/transform/test_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index eb00b8e7d7a0d..afe746e76904f 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -170,7 +170,7 @@ def test_transform_axis_1(transformation_func): if transformation_func == "diff": # Result contains nans, so transpose coerces to float - expected["b"] = expected["b"].astype(int) + expected["b"] = expected["b"].astype(np.intp) # cumcount returns Series; the rest are DataFrame tm.assert_equal(result, expected) From 0fc43f6c64075e53739e877b77cc1d763274d916 Mon Sep 17 00:00:00 2001 From: rhshadrach Date: Fri, 13 Nov 2020 23:51:10 -0500 Subject: [PATCH 5/5] int32 fix --- pandas/tests/groupby/transform/test_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index afe746e76904f..b4e023f569844 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -170,7 +170,7 @@ def test_transform_axis_1(transformation_func): if transformation_func == "diff": # Result contains nans, so transpose coerces to float - expected["b"] = expected["b"].astype(np.intp) + expected["b"] = expected["b"].astype("int64") # cumcount returns Series; the rest are DataFrame tm.assert_equal(result, expected)