Skip to content

Commit d873a46

Browse files
add proper type when grouping by a Series (#708)
* remove Series from GroupByObjectNonScalar, add new orverloads for Series.groupby and DataFrame.groupby * Add tests for iteration over groupby * Add dtype in tests * address PR comments * Use bound method for ByT and SeriesByT typevars
1 parent bb41036 commit d873a46

File tree

5 files changed

+95
-24
lines changed

5 files changed

+95
-24
lines changed

pandas-stubs/_typing.pyi

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -407,33 +407,51 @@ Function: TypeAlias = np.ufunc | Callable[..., Any]
407407
_HashableTa = TypeVar("_HashableTa", bound=Hashable)
408408
ByT = TypeVar(
409409
"ByT",
410-
str,
411-
bytes,
412-
datetime.date,
413-
datetime.datetime,
414-
datetime.timedelta,
415-
np.datetime64,
416-
np.timedelta64,
417-
bool,
418-
int,
419-
float,
420-
complex,
421-
Timestamp,
422-
Timedelta,
423-
Scalar,
424-
Period,
425-
Interval[int],
426-
Interval[float],
427-
Interval[Timestamp],
428-
Interval[Timedelta],
429-
tuple,
410+
bound=str
411+
| bytes
412+
| datetime.date
413+
| datetime.datetime
414+
| datetime.timedelta
415+
| np.datetime64
416+
| np.timedelta64
417+
| bool
418+
| int
419+
| float
420+
| complex
421+
| Timestamp
422+
| Timedelta
423+
| Scalar
424+
| Period
425+
| Interval[int]
426+
| Interval[float]
427+
| Interval[Timestamp]
428+
| Interval[Timedelta]
429+
| tuple,
430+
)
431+
# Use a distinct SeriesByT when using groupby with Series of known dtype.
432+
# Essentially, an intersection between Series S1 TypeVar, and ByT TypeVar
433+
SeriesByT = TypeVar(
434+
"SeriesByT",
435+
bound=str
436+
| bytes
437+
| datetime.date
438+
| bool
439+
| int
440+
| float
441+
| complex
442+
| Timestamp
443+
| Timedelta
444+
| Period
445+
| Interval[int]
446+
| Interval[float]
447+
| Interval[Timestamp]
448+
| Interval[Timedelta],
430449
)
431450
GroupByObjectNonScalar: TypeAlias = (
432451
tuple
433452
| list[_HashableTa]
434453
| Function
435454
| list[Function]
436-
| Series
437455
| list[Series]
438456
| np.ndarray
439457
| list[np.ndarray]
@@ -443,7 +461,7 @@ GroupByObjectNonScalar: TypeAlias = (
443461
| Grouper
444462
| list[Grouper]
445463
)
446-
GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar
464+
GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar | Series
447465

448466
StataDateFormat: TypeAlias = Literal[
449467
"tc",

pandas-stubs/core/frame.pyi

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ from pandas._typing import (
107107
ReplaceMethod,
108108
Scalar,
109109
ScalarT,
110+
SeriesByT,
110111
SortKind,
111112
StataDateFormat,
112113
StorageOptions,
@@ -1087,7 +1088,20 @@ class DataFrame(NDFrame, OpsMixin):
10871088
@overload
10881089
def groupby(
10891090
self,
1090-
by: CategoricalIndex | Index,
1091+
by: Series[SeriesByT],
1092+
axis: Axis = ...,
1093+
level: Level | None = ...,
1094+
as_index: _bool = ...,
1095+
sort: _bool = ...,
1096+
group_keys: _bool = ...,
1097+
squeeze: _bool = ...,
1098+
observed: _bool = ...,
1099+
dropna: _bool = ...,
1100+
) -> DataFrameGroupBy[SeriesByT]: ...
1101+
@overload
1102+
def groupby(
1103+
self,
1104+
by: CategoricalIndex | Index | Series,
10911105
axis: Axis = ...,
10921106
level: Level | None = ...,
10931107
as_index: _bool = ...,

pandas-stubs/core/series.pyi

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ from pandas._typing import (
130130
Renamer,
131131
ReplaceMethod,
132132
Scalar,
133+
SeriesByT,
133134
SortKind,
134135
StrDtypeArg,
135136
TimedeltaDtypeArg,
@@ -635,7 +636,20 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
635636
@overload
636637
def groupby(
637638
self,
638-
by: CategoricalIndex | Index,
639+
by: Series[SeriesByT],
640+
axis: AxisIndex = ...,
641+
level: Level | None = ...,
642+
as_index: _bool = ...,
643+
sort: _bool = ...,
644+
group_keys: _bool = ...,
645+
squeeze: _bool = ...,
646+
observed: _bool = ...,
647+
dropna: _bool = ...,
648+
) -> SeriesGroupBy[S1, SeriesByT]: ...
649+
@overload
650+
def groupby(
651+
self,
652+
by: CategoricalIndex | Index | Series,
639653
axis: AxisIndex = ...,
640654
level: Level | None = ...,
641655
as_index: _bool = ...,

tests/test_frame.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,20 @@ def test_types_groupby_any() -> None:
985985
)
986986

987987

988+
def test_types_groupby_iter() -> None:
989+
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5]})
990+
series_groupby = pd.Series([True, True, False], dtype=bool)
991+
first_group = next(iter(df.groupby(series_groupby)))
992+
check(
993+
assert_type(first_group[0], bool),
994+
bool,
995+
)
996+
check(
997+
assert_type(first_group[1], pd.DataFrame),
998+
pd.DataFrame,
999+
)
1000+
1001+
9881002
def test_types_merge() -> None:
9891003
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5]})
9901004
df2 = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [0, 1, 0]})

tests/test_series.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,17 @@ def test_types_group_by_with_dropna_keyword() -> None:
732732
s.groupby(level=0).sum()
733733

734734

735+
def test_types_groupby_iter() -> None:
736+
s = pd.Series([1, 1, 2], dtype=int)
737+
series_groupby = pd.Series([True, True, False], dtype=bool)
738+
first_group = next(iter(s.groupby(series_groupby)))
739+
check(
740+
assert_type(first_group[0], bool),
741+
bool,
742+
)
743+
check(assert_type(first_group[1], "pd.Series[int]"), pd.Series, np.integer)
744+
745+
735746
def test_types_plot() -> None:
736747
s = pd.Series([0, 1, 1, 0, -10])
737748
if TYPE_CHECKING: # skip pytest

0 commit comments

Comments
 (0)