Typing patch for pandas-stubs (the test added below references GH 167).
Free text ahead of the first file header is skipped by git apply / patch.

  * core/frame.pyi: swaps the signatures of two adjacent overloads (the
    second is `sum`; the first one's `def` line lies outside the hunk,
    presumably `sum` as well) so the positional `level: None` -> Series
    form comes before the keyword-only `level: Level` -> DataFrame form.
  * core/groupby/generic.pyi: imports Iterable and replaces the untyped
    DataFrameGroupBy.apply with three overloads keyed on the callable type.
  * core/groupby/groupby.pyi: drops BaseGroupBy.apply.
  * tests/test_frame.py: imports Callable and adds test_groupby_apply.

diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi
index 54ddd319e..20b4273a9 100644
--- a/pandas-stubs/core/frame.pyi
+++ b/pandas-stubs/core/frame.pyi
@@ -1906,22 +1906,22 @@ class DataFrame(NDFrame, OpsMixin):
         self,
         axis: AxisType | None = ...,
         skipna: _bool | None = ...,
+        level: None = ...,
         numeric_only: _bool | None = ...,
         min_count: int = ...,
-        *,
-        level: Level,
         **kwargs,
-    ) -> DataFrame: ...
+    ) -> Series: ...
     @overload
     def sum(
         self,
         axis: AxisType | None = ...,
         skipna: _bool | None = ...,
-        level: None = ...,
         numeric_only: _bool | None = ...,
         min_count: int = ...,
+        *,
+        level: Level,
         **kwargs,
-    ) -> Series: ...
+    ) -> DataFrame: ...
     def swapaxes(
         self, axis1: AxisType, axis2: AxisType, copy: _bool = ...
     ) -> DataFrame: ...
diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi
index d923e32aa..31617650a 100644
--- a/pandas-stubs/core/groupby/generic.pyi
+++ b/pandas-stubs/core/groupby/generic.pyi
@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import (
     Any,
     Callable,
+    Iterable,
     Iterator,
     Literal,
     NamedTuple,
@@ -111,7 +112,17 @@ class _DataFrameGroupByNonScalar(DataFrameGroupBy):
 class DataFrameGroupBy(GroupBy):
     def any(self, skipna: bool = ...) -> DataFrame: ...
     def all(self, skipna: bool = ...) -> DataFrame: ...
-    def apply(self, func, *args, **kwargs) -> DataFrame: ...
+    # mypy sees the two overloads as overlapping
+    @overload
+    def apply(  # type: ignore[misc]
+        self, func: Callable[[DataFrame], Series | Scalar], *args, **kwargs
+    ) -> Series: ...
+    @overload
+    def apply(  # type: ignore[misc]
+        self, func: Callable[[Iterable], Series | Scalar], *args, **kwargs
+    ) -> DataFrame: ...
+    @overload
+    def apply(self, func: Callable, *args, **kwargs) -> DataFrame | Series: ...
     @overload
     def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ...
     @overload
diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi
index 7bb5560ce..37e8b955b 100644
--- a/pandas-stubs/core/groupby/groupby.pyi
+++ b/pandas-stubs/core/groupby/groupby.pyi
@@ -65,7 +65,6 @@ class BaseGroupBy(PandasObject, SelectionMixin[NDFrameT], GroupByIndexingMixin):
     def pipe(self, func: Callable, *args, **kwargs): ...
     plot = ...
     def get_group(self, name, obj: DataFrame | None = ...) -> DataFrame: ...
-    def apply(self, func: Callable, *args, **kwargs) -> FrameOrSeriesUnion: ...
 
 class GroupBy(BaseGroupBy[NDFrameT]):
     def count(self) -> FrameOrSeriesUnion: ...
diff --git a/tests/test_frame.py b/tests/test_frame.py
index c8a615ee9..27ced9263 100644
--- a/tests/test_frame.py
+++ b/tests/test_frame.py
@@ -7,6 +7,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Hashable,
     Iterable,
     Iterator,
@@ -1339,3 +1340,19 @@ def test_setitem_list():
     iter2: Iterator[tuple[str, int]] = (v for v in lst4)
     check(assert_type(df.set_index(iter1), pd.DataFrame), pd.DataFrame)
     check(assert_type(df.set_index(iter2), pd.DataFrame), pd.DataFrame)
+
+
+def test_groupby_apply() -> None:
+    # GH 167
+    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
+
+    def summean(x: pd.DataFrame) -> float:
+        return x.sum().mean()
+
+    check(assert_type(df.groupby("col1").apply(summean), pd.Series), pd.Series)
+
+    lfunc: Callable[[pd.DataFrame], float] = lambda x: x.sum().mean()
+    check(
+        assert_type(df.groupby("col1").apply(lfunc), pd.Series),
+        pd.Series,
+    )