
Commit efbf499

PERF: GroupBy.count (#43730)
Parent: cd13e3a

4 files changed (+49 −18 lines)


pandas/core/groupby/generic.py

Lines changed: 12 additions & 8 deletions
```diff
@@ -21,6 +21,7 @@
     Mapping,
     TypeVar,
     Union,
+    cast,
 )
 import warnings
 
@@ -30,7 +31,9 @@
 from pandas._typing import (
     ArrayLike,
     FrameOrSeries,
+    Manager,
     Manager2D,
+    SingleManager,
 )
 from pandas.util._decorators import (
     Appender,
@@ -80,7 +83,6 @@
     Index,
     MultiIndex,
     all_indexes_same,
-    default_index,
 )
 from pandas.core.series import Series
 from pandas.core.util.numba_ import maybe_use_numba
@@ -159,19 +161,21 @@ def pinner(cls):
 class SeriesGroupBy(GroupBy[Series]):
     _apply_allowlist = base.series_apply_allowlist
 
-    def _wrap_agged_manager(self, mgr: Manager2D) -> Series:
-        single = mgr.iget(0)
+    def _wrap_agged_manager(self, mgr: Manager) -> Series:
+        if mgr.ndim == 1:
+            mgr = cast(SingleManager, mgr)
+            single = mgr
+        else:
+            mgr = cast(Manager2D, mgr)
+            single = mgr.iget(0)
         ser = self.obj._constructor(single, name=self.obj.name)
         # NB: caller is responsible for setting ser.index
         return ser
 
-    def _get_data_to_aggregate(self) -> Manager2D:
+    def _get_data_to_aggregate(self) -> SingleManager:
         ser = self._obj_with_exclusions
         single = ser._mgr
-        columns = default_index(1)
-        # Much faster than using ser.to_frame() since we avoid inferring columns
-        # from scalar
-        return single.to_2d_mgr(columns)
+        return single
 
     def _iterate_slices(self) -> Iterable[Series]:
         yield self._selected_obj
```
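In effect, SeriesGroupBy reductions now hand the aggregation machinery the Series' own one-dimensional manager, and _wrap_agged_manager rewraps the reduced 1D result directly instead of pulling column 0 out of a faked one-column block. A quick behavioral check (the output is unchanged by this commit; only the internal path differs):

```python
import numpy as np
import pandas as pd

ser = pd.Series([1.0, np.nan, 3.0, 4.0], index=["a", "a", "b", "b"], name="x")
out = ser.groupby(level=0).count()  # the NaN in group "a" is not counted
print(out)
# a    1
# b    2
# Name: x, dtype: int64
```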

pandas/core/groupby/groupby.py

Lines changed: 11 additions & 4 deletions
```diff
@@ -1745,6 +1745,8 @@ def count(self) -> Series | DataFrame:
         ids, _, ngroups = self.grouper.group_info
         mask = ids != -1
 
+        is_series = data.ndim == 1
+
         def hfunc(bvalues: ArrayLike) -> ArrayLike:
             # TODO(2DEA): reshape would not be necessary with 2D EAs
             if bvalues.ndim == 1:
@@ -1754,6 +1756,10 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
             masked = mask & ~isna(bvalues)
 
             counted = lib.count_level_2d(masked, labels=ids, max_bin=ngroups, axis=1)
+            if is_series:
+                assert counted.ndim == 2
+                assert counted.shape[0] == 1
+                return counted[0]
             return counted
 
         new_mgr = data.grouped_reduce(hfunc)
```
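lib.count_level_2d is a Cython helper, so the sketch below is a pure-NumPy stand-in (illustrative only, not the pandas source) for the shape contract the new is_series branch relies on: a Series block contributes exactly one row, so counted has shape (1, ngroups) and counted[0] is the 1D result.

```python
import numpy as np

def count_level_2d(masked: np.ndarray, labels: np.ndarray, max_bin: int) -> np.ndarray:
    # Count, per row and per group label, the entries that survived
    # mask & ~isna(bvalues); mirrors the axis=1 call in the diff.
    counted = np.zeros((masked.shape[0], max_bin), dtype=np.int64)
    for i in range(masked.shape[0]):
        for lab, valid in zip(labels, masked[i]):
            if valid:
                counted[i, lab] += 1
    return counted

masked = np.array([[True, False, True, True]])  # one row: the Series' only "column"
labels = np.array([0, 0, 1, 1])                 # group ids
print(count_level_2d(masked, labels, max_bin=2))
# [[1 2]] -- shape (1, ngroups), hence the counted[0] fast exit above
```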
```diff
@@ -2702,7 +2708,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:
         mgr = self._get_data_to_aggregate()
 
         res_mgr = mgr.grouped_reduce(blk_func, ignore_failures=True)
-        if len(res_mgr.items) != len(mgr.items):
+        if not is_ser and len(res_mgr.items) != len(mgr.items):
             warnings.warn(
                 "Dropping invalid columns in "
                 f"{type(self).__name__}.quantile is deprecated. "
```
```diff
@@ -3134,14 +3140,15 @@ def blk_func(values: ArrayLike) -> ArrayLike:
         obj = self._obj_with_exclusions
 
         # Operate block-wise instead of column-by-column
-        orig_ndim = obj.ndim
+        is_ser = obj.ndim == 1
         mgr = self._get_data_to_aggregate()
 
         if numeric_only:
             mgr = mgr.get_numeric_data()
 
         res_mgr = mgr.grouped_reduce(blk_func, ignore_failures=True)
-        if len(res_mgr.items) != len(mgr.items):
+
+        if not is_ser and len(res_mgr.items) != len(mgr.items):
             howstr = how.replace("group_", "")
             warnings.warn(
                 "Dropping invalid columns in "
@@ -3162,7 +3169,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:
                 # We should never get here
                 raise TypeError("All columns were dropped in grouped_reduce")
 
-        if orig_ndim == 1:
+        if is_ser:
             out = self._wrap_agged_manager(res_mgr)
             out.index = self.grouper.result_index
         else:
```
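Since this is a PERF commit, the effect is best observed by benchmarking a Series groupby reduction on this commit against its parent cd13e3a. A rough setup (numbers are machine-dependent; the %timeit line is meant for IPython):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
ser = pd.Series(rng.standard_normal(1_000_000),
                index=rng.integers(0, 1_000, 1_000_000))
gb = ser.groupby(level=0)

# %timeit gb.count()
# The Series path now skips the Series -> one-column-manager round trip,
# so this should run measurably faster than on the parent commit.
print(gb.count().head())
```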

pandas/core/internals/base.py

Lines changed: 22 additions & 1 deletion
```diff
@@ -10,6 +10,7 @@
 )
 
 from pandas._typing import (
+    ArrayLike,
     DtypeObj,
     Shape,
 )
@@ -18,7 +19,10 @@
 from pandas.core.dtypes.cast import find_common_type
 
 from pandas.core.base import PandasObject
-from pandas.core.indexes.api import Index
+from pandas.core.indexes.api import (
+    Index,
+    default_index,
+)
 
 T = TypeVar("T", bound="DataManager")
 
@@ -171,6 +175,23 @@ def setitem_inplace(self, indexer, value) -> None:
         """
         self.array[indexer] = value
 
+    def grouped_reduce(self, func, ignore_failures: bool = False):
+        """
+        ignore_failures : bool, default False
+            Not used; for compatibility with ArrayManager/BlockManager.
+        """
+
+        arr = self.array
+        res = func(arr)
+        index = default_index(len(res))
+
+        mgr = type(self).from_array(res, index)
+        return mgr
+
+    @classmethod
+    def from_array(cls, arr: ArrayLike, index: Index):
+        raise AbstractMethodError(cls)
+
 
 def interleaved_dtype(dtypes: list[DtypeObj]) -> DtypeObj | None:
     """
```

pandas/tests/resample/test_resample_api.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -350,11 +350,10 @@ def test_agg():
     expected = pd.concat([a_mean, a_std, b_mean, b_std], axis=1)
     expected.columns = pd.MultiIndex.from_product([["A", "B"], ["mean", "std"]])
     for t in cases:
-        warn = FutureWarning if t in cases[1:3] else None
-        with tm.assert_produces_warning(
-            warn, match="Dropping invalid columns", check_stacklevel=False
-        ):
-            # .var on dt64 column raises and is dropped
+        with tm.assert_produces_warning(None):
+            # .var on dt64 column raises and is dropped, but the path in core.apply
+            # that it goes through will still suppress a TypeError even
+            # once the deprecations in the groupby code are enforced
             result = t.aggregate([np.mean, np.std])
         tm.assert_frame_equal(result, expected)
 
```
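For context, passing None to tm.assert_produces_warning makes it assert that the enclosed block emits no warning at all, which is what the updated test now demands. A minimal usage sketch (illustrative data, not the test's fixtures):

```python
import pandas as pd
import pandas._testing as tm

ser = pd.Series([1.0, 2.0], index=["a", "b"])
with tm.assert_produces_warning(None):  # fails if any warning is raised inside
    ser.groupby(level=0).count()
```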
