From b1d2e307f825b71501daaa95fd7e8f05ca9b4dc2 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 24 Jul 2022 20:12:54 -0400 Subject: [PATCH 1/3] Fix groupby.apply() and sum(axis=1) --- pandas-stubs/core/frame.pyi | 12 +++++++++++- pandas-stubs/core/groupby/generic.pyi | 12 +++++++++++- pandas-stubs/core/groupby/groupby.pyi | 1 - tests/test_frame.py | 15 +++++++++++++++ 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 54ddd319e..7958b2097 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1902,6 +1902,16 @@ class DataFrame(NDFrame, OpsMixin): fill_value: float | None = ..., ) -> DataFrame: ... @overload + def sum( + self, + axis: Literal[1, "columns"], + skipna: _bool | None = ..., + level: None = ..., + numeric_only: _bool | None = ..., + min_count: int = ..., + **kwargs, + ) -> Series: ... + @overload def sum( self, axis: AxisType | None = ..., @@ -1909,7 +1919,7 @@ class DataFrame(NDFrame, OpsMixin): numeric_only: _bool | None = ..., min_count: int = ..., *, - level: Level, + level: Level = ..., **kwargs, ) -> DataFrame: ... @overload diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index d923e32aa..501d658ad 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -3,6 +3,7 @@ from __future__ import annotations from typing import ( Any, Callable, + Iterable, Iterator, Literal, NamedTuple, @@ -111,7 +112,16 @@ class _DataFrameGroupByNonScalar(DataFrameGroupBy): class DataFrameGroupBy(GroupBy): def any(self, skipna: bool = ...) -> DataFrame: ... def all(self, skipna: bool = ...) -> DataFrame: ... - def apply(self, func, *args, **kwargs) -> DataFrame: ... + @overload + def apply( + self, func: Callable[[DataFrame], DataFrame], *args, **kwargs + ) -> DataFrame: ... + @overload + def apply(self, func: Callable[[DataFrame], Series], *args, **kwargs) -> Series: ... + @overload + def apply( + self, func: Callable[[Iterable], float], *args, **kwargs + ) -> DataFrame: ... @overload def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ... @overload diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index 7bb5560ce..37e8b955b 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -65,7 +65,6 @@ class BaseGroupBy(PandasObject, SelectionMixin[NDFrameT], GroupByIndexingMixin): def pipe(self, func: Callable, *args, **kwargs): ... plot = ... def get_group(self, name, obj: DataFrame | None = ...) -> DataFrame: ... - def apply(self, func: Callable, *args, **kwargs) -> FrameOrSeriesUnion: ... class GroupBy(BaseGroupBy[NDFrameT]): def count(self) -> FrameOrSeriesUnion: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index c8a615ee9..928bfed65 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1339,3 +1339,18 @@ def test_setitem_list(): iter2: Iterator[tuple[str, int]] = (v for v in lst4) check(assert_type(df.set_index(iter1), pd.DataFrame), pd.DataFrame) check(assert_type(df.set_index(iter2), pd.DataFrame), pd.DataFrame) + + +def test_groupby_apply() -> None: + # GH 167 + df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) + + def summean(x: pd.DataFrame) -> pd.Series: + return x.sum().mean() + + check(assert_type(df.groupby("col1").apply(summean), pd.Series), pd.Series) + + check( + assert_type(df.groupby("col1").apply(lambda x: x.sum().mean()), pd.Series), + pd.Series, + ) From 2a6b37ec74c61e18e7371e10a6d4792df3b8adb3 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Mon, 25 Jul 2022 08:39:39 -0400 Subject: [PATCH 2/3] reorder sum overloads, change apply overloads --- pandas-stubs/core/frame.pyi | 14 ++------------ pandas-stubs/core/groupby/generic.pyi | 14 +++++--------- tests/test_frame.py | 6 ++++-- 3 files changed, 11 insertions(+), 23 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 7958b2097..20b4273a9 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1904,7 +1904,7 @@ class DataFrame(NDFrame, OpsMixin): @overload def sum( self, - axis: Literal[1, "columns"], + axis: AxisType | None = ..., skipna: _bool | None = ..., level: None = ..., numeric_only: _bool | None = ..., @@ -1919,19 +1919,9 @@ class DataFrame(NDFrame, OpsMixin): numeric_only: _bool | None = ..., min_count: int = ..., *, - level: Level = ..., + level: Level, **kwargs, ) -> DataFrame: ... - @overload - def sum( - self, - axis: AxisType | None = ..., - skipna: _bool | None = ..., - level: None = ..., - numeric_only: _bool | None = ..., - min_count: int = ..., - **kwargs, - ) -> Series: ... def swapaxes( self, axis1: AxisType, axis2: AxisType, copy: _bool = ... ) -> DataFrame: ... diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 501d658ad..9a2fea027 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -3,7 +3,6 @@ from __future__ import annotations from typing import ( Any, Callable, - Iterable, Iterator, Literal, NamedTuple, @@ -112,16 +111,13 @@ class _DataFrameGroupByNonScalar(DataFrameGroupBy): class DataFrameGroupBy(GroupBy): def any(self, skipna: bool = ...) -> DataFrame: ... def all(self, skipna: bool = ...) -> DataFrame: ... + # mypy sees the two overloads as overlapping @overload - def apply( - self, func: Callable[[DataFrame], DataFrame], *args, **kwargs - ) -> DataFrame: ... - @overload - def apply(self, func: Callable[[DataFrame], Series], *args, **kwargs) -> Series: ... + def apply( # type: ignore[misc] + self, func: Callable[[DataFrame], Series | Scalar], *args, **kwargs + ) -> Series: ... @overload - def apply( - self, func: Callable[[Iterable], float], *args, **kwargs - ) -> DataFrame: ... + def apply(self, func: Callable, *args, **kwargs) -> DataFrame: ... @overload def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ... @overload diff --git a/tests/test_frame.py b/tests/test_frame.py index 928bfed65..27ced9263 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -7,6 +7,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Hashable, Iterable, Iterator, @@ -1345,12 +1346,13 @@ def test_groupby_apply() -> None: # GH 167 df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) - def summean(x: pd.DataFrame) -> pd.Series: + def summean(x: pd.DataFrame) -> float: return x.sum().mean() check(assert_type(df.groupby("col1").apply(summean), pd.Series), pd.Series) + lfunc: Callable[[pd.DataFrame], float] = lambda x: x.sum().mean() check( - assert_type(df.groupby("col1").apply(lambda x: x.sum().mean()), pd.Series), + assert_type(df.groupby("col1").apply(lfunc), pd.Series), pd.Series, ) From 999f147c50d6c424ed889bbc957bc70e1da91f44 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Mon, 25 Jul 2022 10:03:33 -0400 Subject: [PATCH 3/3] make the catchall return a Union --- pandas-stubs/core/groupby/generic.pyi | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 9a2fea027..31617650a 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -3,6 +3,7 @@ from __future__ import annotations from typing import ( Any, Callable, + Iterable, Iterator, Literal, NamedTuple, @@ -117,7 +118,11 @@ class DataFrameGroupBy(GroupBy): self, func: Callable[[DataFrame], Series | Scalar], *args, **kwargs ) -> Series: ... @overload - def apply(self, func: Callable, *args, **kwargs) -> DataFrame: ... + def apply( # type: ignore[misc] + self, func: Callable[[Iterable], Series | Scalar], *args, **kwargs + ) -> DataFrame: ... + @overload + def apply(self, func: Callable, *args, **kwargs) -> DataFrame | Series: ... @overload def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ... @overload