Skip to content

Commit 02e1748

Browse files
More specific types for GroupBy.apply. (#177)
* More specific typing for DataFrameGroupBy.apply. * Add missing SeriesGroupBy.sum * Reorder apply overloads.
1 parent c915a8b commit 02e1748

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class SeriesGroupBy(GroupBy):
100100
def nlargest(self, n: int = ..., keep: str = ...) -> Series[S1]: ...
101101
def nsmallest(self, n: int = ..., keep: str = ...) -> Series[S1]: ...
102102
def nth(self, n: int | Sequence[int], dropna: str | None = ...) -> Series[S1]: ...
103+
def sum(self, **kwargs) -> Series[S1]: ...
103104

104105
class _DataFrameGroupByScalar(DataFrameGroupBy):
105106
def __iter__(self) -> Iterator[tuple[Scalar, DataFrame]]: ...
@@ -110,17 +111,19 @@ class _DataFrameGroupByNonScalar(DataFrameGroupBy):
110111
class DataFrameGroupBy(GroupBy):
111112
def any(self, skipna: bool = ...) -> DataFrame: ...
112113
def all(self, skipna: bool = ...) -> DataFrame: ...
113-
# mypy sees the two overloads as overlapping
114+
# mypy and pyright see these overloads as overlapping
114115
@overload
115116
def apply( # type: ignore[misc]
116-
self, func: Callable[[DataFrame], Series | Scalar], *args, **kwargs
117+
self, func: Callable[[DataFrame], Scalar | list | dict], *args, **kwargs
117118
) -> Series: ...
118119
@overload
119120
def apply( # type: ignore[misc]
120-
self, func: Callable[[Iterable], Series | Scalar], *args, **kwargs
121+
self, func: Callable[[DataFrame], Series | DataFrame], *args, **kwargs
121122
) -> DataFrame: ...
122123
@overload
123-
def apply(self, func: Callable, *args, **kwargs) -> DataFrame | Series: ...
124+
def apply( # type: ignore[misc]
125+
self, func: Callable[[Iterable], float], *args, **kwargs
126+
) -> DataFrame: ...
124127
@overload
125128
def aggregate(self, arg: str, *args, **kwargs) -> DataFrame: ...
126129
@overload

tests/test_frame.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,13 +1346,32 @@ def test_groupby_apply() -> None:
13461346
# GH 167
13471347
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
13481348

1349-
def summean(x: pd.DataFrame) -> float:
1349+
def sum_mean(x: pd.DataFrame) -> float:
13501350
return x.sum().mean()
13511351

1352-
check(assert_type(df.groupby("col1").apply(summean), pd.Series), pd.Series)
1352+
check(assert_type(df.groupby("col1").apply(sum_mean), pd.Series), pd.Series)
13531353

13541354
lfunc: Callable[[pd.DataFrame], float] = lambda x: x.sum().mean()
13551355
check(
13561356
assert_type(df.groupby("col1").apply(lfunc), pd.Series),
13571357
pd.Series,
13581358
)
1359+
1360+
def sum_to_list(x: pd.DataFrame) -> list:
1361+
return x.sum().tolist()
1362+
1363+
check(assert_type(df.groupby("col1").apply(sum_to_list), pd.Series), pd.Series)
1364+
1365+
def sum_to_series(x: pd.DataFrame) -> pd.Series:
1366+
return x.sum()
1367+
1368+
check(
1369+
assert_type(df.groupby("col1").apply(sum_to_series), pd.DataFrame), pd.DataFrame
1370+
)
1371+
1372+
def sample_to_df(x: pd.DataFrame) -> pd.DataFrame:
1373+
return x.sample()
1374+
1375+
check(
1376+
assert_type(df.groupby("col1").apply(sample_to_df), pd.DataFrame), pd.DataFrame
1377+
)

0 commit comments

Comments
 (0)