diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 63087672d1365..7259268ac3f2b 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -169,7 +169,7 @@ def apply(self, f, data: FrameOrSeries, axis: int = 0): and not sdata.index._has_complex_internals ): try: - result_values, mutated = splitter.fast_apply(f, group_keys) + result_values, mutated = splitter.fast_apply(f, sdata, group_keys) except libreduction.InvalidApply as err: # This Exception is raised if `f` triggers an exception @@ -925,11 +925,9 @@ def _chop(self, sdata: Series, slice_obj: slice) -> Series: class FrameSplitter(DataSplitter): - def fast_apply(self, f, names): + def fast_apply(self, f, sdata: FrameOrSeries, names): # must return keys::list, values::list, mutated::bool starts, ends = lib.generate_slices(self.slabels, self.ngroups) - - sdata = self._get_sorted_data() return libreduction.apply_frame_axis0(sdata, f, names, starts, ends) def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame: @@ -939,11 +937,13 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame: return sdata.iloc[:, slice_obj] -def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter: +def get_splitter( + data: FrameOrSeries, labels: np.ndarray, ngroups: int, axis: int = 0 +) -> DataSplitter: if isinstance(data, Series): klass: Type[DataSplitter] = SeriesSplitter else: # i.e. DataFrame klass = FrameSplitter - return klass(data, *args, **kwargs) + return klass(data, labels, ngroups, axis) diff --git a/pandas/tests/groupby/test_apply.py b/pandas/tests/groupby/test_apply.py index 41ec70468aaeb..18ad5d90b3f60 100644 --- a/pandas/tests/groupby/test_apply.py +++ b/pandas/tests/groupby/test_apply.py @@ -108,8 +108,9 @@ def f(g): splitter = grouper._get_splitter(g._selected_obj, axis=g.axis) group_keys = grouper._get_group_keys() + sdata = splitter._get_sorted_data() - values, mutated = splitter.fast_apply(f, group_keys) + values, mutated = splitter.fast_apply(f, sdata, group_keys) assert not mutated