Skip to content

Commit bab4f30

Browse files
authored
REF: simplify python_agg_general (#51447)
1 parent 4965c51 commit bab4f30

File tree

2 files changed

+53
-43
lines changed

2 files changed

+53
-43
lines changed

pandas/core/groupby/generic.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -250,14 +250,28 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
250250
if cyfunc and not args and not kwargs:
251251
return getattr(self, cyfunc)()
252252

253+
if self.ngroups == 0:
254+
# e.g. test_evaluate_with_empty_groups without any groups to
255+
# iterate over, we have no output on which to do dtype
256+
# inference. We default to using the existing dtype.
257+
# xref GH#51445
258+
obj = self._obj_with_exclusions
259+
return self.obj._constructor(
260+
[],
261+
name=self.obj.name,
262+
index=self.grouper.result_index,
263+
dtype=obj.dtype,
264+
)
265+
253266
if self.grouper.nkeys > 1:
254267
return self._python_agg_general(func, *args, **kwargs)
255268

256269
try:
257270
return self._python_agg_general(func, *args, **kwargs)
258271
except KeyError:
259-
# TODO: KeyError is raised in _python_agg_general,
260-
# see test_groupby.test_basic
272+
# KeyError raised in test_groupby.test_basic is bc the func does
273+
# a dictionary lookup on group.name, but group name is not
274+
# pinned in _python_agg_general, only in _aggregate_named
261275
result = self._aggregate_named(func, *args, **kwargs)
262276

263277
# result is a dict whose keys are the elements of result_index
@@ -267,6 +281,15 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
267281

268282
agg = aggregate
269283

284+
def _python_agg_general(self, func, *args, **kwargs):
285+
func = com.is_builtin_func(func)
286+
f = lambda x: func(x, *args, **kwargs)
287+
288+
obj = self._obj_with_exclusions
289+
result = self.grouper.agg_series(obj, f)
290+
res = obj._constructor(result, name=obj.name)
291+
return self._wrap_aggregated_output(res)
292+
270293
def _aggregate_multiple_funcs(self, arg, *args, **kwargs) -> DataFrame:
271294
if isinstance(arg, dict):
272295
if self.as_index:
@@ -308,18 +331,6 @@ def _aggregate_multiple_funcs(self, arg, *args, **kwargs) -> DataFrame:
308331
output = self._reindex_output(output)
309332
return output
310333

311-
def _indexed_output_to_ndframe(
312-
self, output: Mapping[base.OutputKey, ArrayLike]
313-
) -> Series:
314-
"""
315-
Wrap the dict result of a GroupBy aggregation into a Series.
316-
"""
317-
assert len(output) == 1
318-
values = next(iter(output.values()))
319-
result = self.obj._constructor(values)
320-
result.name = self.obj.name
321-
return result
322-
323334
def _wrap_applied_output(
324335
self,
325336
data: Series,
@@ -1319,6 +1330,31 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
13191330

13201331
agg = aggregate
13211332

1333+
def _python_agg_general(self, func, *args, **kwargs):
1334+
func = com.is_builtin_func(func)
1335+
f = lambda x: func(x, *args, **kwargs)
1336+
1337+
# iterate through "columns" ex exclusions to populate output dict
1338+
output: dict[base.OutputKey, ArrayLike] = {}
1339+
1340+
if self.ngroups == 0:
1341+
# e.g. test_evaluate_with_empty_groups different path gets different
1342+
# result dtype in empty case.
1343+
return self._python_apply_general(f, self._selected_obj, is_agg=True)
1344+
1345+
for idx, obj in enumerate(self._iterate_slices()):
1346+
name = obj.name
1347+
result = self.grouper.agg_series(obj, f)
1348+
key = base.OutputKey(label=name, position=idx)
1349+
output[key] = result
1350+
1351+
if not output:
1352+
# e.g. test_margins_no_values_no_cols
1353+
return self._python_apply_general(f, self._selected_obj)
1354+
1355+
res = self._indexed_output_to_ndframe(output)
1356+
return self._wrap_aggregated_output(res)
1357+
13221358
def _iterate_slices(self) -> Iterable[Series]:
13231359
obj = self._selected_obj
13241360
if self.axis == 1:
@@ -1885,7 +1921,9 @@ def nunique(self, dropna: bool = True) -> DataFrame:
18851921

18861922
if self.axis != 0:
18871923
# see test_groupby_crash_on_nunique
1888-
return self._python_agg_general(lambda sgb: sgb.nunique(dropna))
1924+
return self._python_apply_general(
1925+
lambda sgb: sgb.nunique(dropna), self._obj_with_exclusions, is_agg=True
1926+
)
18891927

18901928
obj = self._obj_with_exclusions
18911929
results = self._apply_to_column_groupbys(

pandas/core/groupby/groupby.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,34 +1409,6 @@ def _python_apply_general(
14091409
is_transform,
14101410
)
14111411

1412-
# TODO: I (jbrockmendel) think this should be equivalent to doing grouped_reduce
1413-
# on _agg_py_fallback, but trying that here fails a bunch of tests 2023-02-07.
1414-
@final
1415-
def _python_agg_general(self, func, *args, **kwargs):
1416-
func = com.is_builtin_func(func)
1417-
f = lambda x: func(x, *args, **kwargs)
1418-
1419-
# iterate through "columns" ex exclusions to populate output dict
1420-
output: dict[base.OutputKey, ArrayLike] = {}
1421-
1422-
if self.ngroups == 0:
1423-
# e.g. test_evaluate_with_empty_groups different path gets different
1424-
# result dtype in empty case.
1425-
return self._python_apply_general(f, self._selected_obj, is_agg=True)
1426-
1427-
for idx, obj in enumerate(self._iterate_slices()):
1428-
name = obj.name
1429-
result = self.grouper.agg_series(obj, f)
1430-
key = base.OutputKey(label=name, position=idx)
1431-
output[key] = result
1432-
1433-
if not output:
1434-
# e.g. test_groupby_crash_on_nunique, test_margins_no_values_no_cols
1435-
return self._python_apply_general(f, self._selected_obj)
1436-
1437-
res = self._indexed_output_to_ndframe(output)
1438-
return self._wrap_aggregated_output(res)
1439-
14401412
@final
14411413
def _agg_general(
14421414
self,

0 commit comments

Comments
 (0)