From 3df69b524c9dbe8a64b0ff40191f1849bb233d3e Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 11 Feb 2020 23:01:36 -0800 Subject: [PATCH 1/2] CLN: Some groupby internals --- pandas/core/groupby/ops.py | 12 ++++++------ pandas/tests/groupby/test_apply.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 4e593ce543ea6..79c2a0e309144 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -169,7 +169,7 @@ def apply(self, f, data: FrameOrSeries, axis: int = 0): and not sdata.index._has_complex_internals ): try: - result_values, mutated = splitter.fast_apply(f, group_keys) + result_values, mutated = splitter.fast_apply(f, sdata, group_keys) except libreduction.InvalidApply as err: # This Exception is raised if `f` triggers an exception @@ -927,11 +927,9 @@ def _chop(self, sdata: Series, slice_obj: slice) -> Series: class FrameSplitter(DataSplitter): - def fast_apply(self, f, names): + def fast_apply(self, f, sdata, names): # must return keys::list, values::list, mutated::bool starts, ends = lib.generate_slices(self.slabels, self.ngroups) - - sdata = self._get_sorted_data() return libreduction.apply_frame_axis0(sdata, f, names, starts, ends) def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame: @@ -941,11 +939,13 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame: return sdata.iloc[:, slice_obj] -def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter: +def get_splitter( + data: FrameOrSeries, labels, ngroups: int, axis: int = 0 +) -> DataSplitter: if isinstance(data, Series): klass: Type[DataSplitter] = SeriesSplitter else: # i.e. DataFrame klass = FrameSplitter - return klass(data, *args, **kwargs) + return klass(data, labels, ngroups, axis) diff --git a/pandas/tests/groupby/test_apply.py b/pandas/tests/groupby/test_apply.py index 41ec70468aaeb..18ad5d90b3f60 100644 --- a/pandas/tests/groupby/test_apply.py +++ b/pandas/tests/groupby/test_apply.py @@ -108,8 +108,9 @@ def f(g): splitter = grouper._get_splitter(g._selected_obj, axis=g.axis) group_keys = grouper._get_group_keys() + sdata = splitter._get_sorted_data() - values, mutated = splitter.fast_apply(f, group_keys) + values, mutated = splitter.fast_apply(f, sdata, group_keys) assert not mutated From cbf4d7591e3308868005ae950074a0a26e09e235 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Wed, 12 Feb 2020 09:03:46 -0800 Subject: [PATCH 2/2] Additional annotation --- pandas/core/groupby/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index d9b5def58c495..7259268ac3f2b 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -925,7 +925,7 @@ def _chop(self, sdata: Series, slice_obj: slice) -> Series: class FrameSplitter(DataSplitter): - def fast_apply(self, f, sdata, names): + def fast_apply(self, f, sdata: FrameOrSeries, names): # must return keys::list, values::list, mutated::bool starts, ends = lib.generate_slices(self.slabels, self.ngroups) return libreduction.apply_frame_axis0(sdata, f, names, starts, ends) @@ -938,7 +938,7 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame: def get_splitter( - data: FrameOrSeries, labels, ngroups: int, axis: int = 0 + data: FrameOrSeries, labels: np.ndarray, ngroups: int, axis: int = 0 ) -> DataSplitter: if isinstance(data, Series): klass: Type[DataSplitter] = SeriesSplitter