diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 5ef1f9dea5091..88524b1f458ff 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -907,6 +907,7 @@ Groupby/resample/rolling to the input DataFrame is inconsistent. An internal heuristic to detect index mutation would behave differently for equal but not identical indices. In particular, the result index shape might change if a copy of the input would be returned. The behaviour now is consistent, independent of internal heuristics. (:issue:`31612`, :issue:`14927`, :issue:`13056`) +- Bug in :meth:`SeriesGroupBy.agg` where any column name was accepted in the named aggregation of ``SeriesGroupBy`` previously. The behaviour now allows only ``str`` and callables else would raise ``TypeError``. (:issue:`34422`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/aggregation.py b/pandas/core/aggregation.py index 6130e05b2a4dc..838722f60b380 100644 --- a/pandas/core/aggregation.py +++ b/pandas/core/aggregation.py @@ -5,7 +5,7 @@ from collections import defaultdict from functools import partial -from typing import Any, DefaultDict, List, Sequence, Tuple +from typing import Any, Callable, DefaultDict, List, Sequence, Tuple, Union from pandas.core.dtypes.common import is_dict_like, is_list_like @@ -196,3 +196,39 @@ def maybe_mangle_lambdas(agg_spec: Any) -> Any: mangled_aggspec = _managle_lambda_list(agg_spec) return mangled_aggspec + + +def validate_func_kwargs( + kwargs: dict, +) -> Tuple[List[str], List[Union[str, Callable[..., Any]]]]: + """ + Validates types of user-provided "named aggregation" kwargs. + `TypeError` is raised if aggfunc is not `str` or callable. + + Parameters + ---------- + kwargs : dict + + Returns + ------- + columns : List[str] + List of user-provied keys. + func : List[Union[str, callable[...,Any]]] + List of user-provided aggfuncs + + Examples + -------- + >>> validate_func_kwargs({'one': 'min', 'two': 'max'}) + (['one', 'two'], ['min', 'max']) + """ + no_arg_message = "Must provide 'func' or named aggregation **kwargs." + tuple_given_message = "func is expected but recieved {} in **kwargs." + columns = list(kwargs) + func = [] + for col_func in kwargs.values(): + if not (isinstance(col_func, str) or callable(col_func)): + raise TypeError(tuple_given_message.format(type(col_func).__name__)) + func.append(col_func) + if not columns: + raise TypeError(no_arg_message) + return columns, func diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index ea4b6f4e65341..d589b0e0fe83c 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -57,6 +57,7 @@ is_multi_agg_with_relabel, maybe_mangle_lambdas, normalize_keyword_aggregation, + validate_func_kwargs, ) import pandas.core.algorithms as algorithms from pandas.core.base import DataError, SpecificationError @@ -233,13 +234,9 @@ def aggregate( relabeling = func is None columns = None - no_arg_message = "Must provide 'func' or named aggregation **kwargs." if relabeling: - columns = list(kwargs) - func = [kwargs[col] for col in columns] + columns, func = validate_func_kwargs(kwargs) kwargs = {} - if not columns: - raise TypeError(no_arg_message) if isinstance(func, str): return getattr(self, func)(*args, **kwargs) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index d4b061594c364..371ec11cdba77 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -511,6 +511,21 @@ def test_mangled(self): expected = pd.DataFrame({"a": [0, 0], "b": [1, 1]}) tm.assert_frame_equal(result, expected) + @pytest.mark.parametrize( + "inp", + [ + pd.NamedAgg(column="anything", aggfunc="min"), + ("anything", "min"), + ["anything", "min"], + ], + ) + def test_named_agg_nametuple(self, inp): + # GH34422 + s = pd.Series([1, 1, 2, 2, 3, 3, 4, 5]) + msg = f"func is expected but recieved {type(inp).__name__}" + with pytest.raises(TypeError, match=msg): + s.groupby(s.values).agg(a=inp) + class TestNamedAggregationDataFrame: def test_agg_relabel(self):