Skip to content

Commit 3d783b7

Browse files
committed
case_when
1 parent 9c1c9a0 commit 3d783b7

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

pandas-stubs/core/series.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ from pandas._typing import (
220220
WriteBuffer,
221221
_T_co,
222222
np_1darray,
223+
np_1darray_bool,
223224
np_1darray_dt,
224225
np_1darray_int64,
225226
np_1darray_intp,
@@ -1678,11 +1679,12 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame):
16781679
) -> Series[S1]: ...
16791680
def case_when(
16801681
self,
1681-
caselist: list[
1682+
caselist: Sequence[
16821683
tuple[
16831684
Sequence[bool]
1685+
| np_1darray_bool
16841686
| Series[bool]
1685-
| Callable[[Series], Series | np_ndarray | Sequence[bool]],
1687+
| Callable[[Series], Sequence[bool] | np_1darray_bool | Series[bool]],
16861688
ListLikeU | Scalar | Callable[[Series], Series | np_ndarray],
16871689
],
16881690
],

tests/series/test_series.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import (
4+
Callable,
45
Hashable,
56
Iterable,
67
Iterator,
@@ -60,6 +61,7 @@
6061
check,
6162
ensure_clean,
6263
np_1darray,
64+
np_1darray_bool,
6365
np_ndarray_num,
6466
pytest_warns_bounded,
6567
)
@@ -3623,13 +3625,20 @@ def test_case_when() -> None:
36233625
c = pd.Series([6, 7, 8, 9], name="c")
36243626
a = pd.Series([0, 0, 1, 2])
36253627
b = pd.Series([0, 3, 4, 5])
3626-
r = c.case_when(
3627-
caselist=[
3628-
(a.gt(0), a),
3629-
(b.gt(0), b),
3630-
]
3631-
)
3632-
check(assert_type(r, pd.Series), pd.Series)
3628+
3629+
c0 = [(a.gt(0), a), (b.gt(0), b)]
3630+
check(assert_type(c.case_when(c0), pd.Series), pd.Series)
3631+
3632+
def foo_factory(
3633+
thresh: int,
3634+
) -> Callable[[pd.Series], pd.Series[bool] | np_1darray_bool]:
3635+
def foo(s: pd.Series) -> pd.Series[bool] | np_1darray_bool:
3636+
return s >= thresh
3637+
3638+
return foo
3639+
3640+
c1 = [(foo_factory(2), a), (foo_factory(0), b)]
3641+
check(assert_type(c.case_when(c1), pd.Series), pd.Series)
36333642

36343643

36353644
def test_series_unique_timestamp() -> None:

0 commit comments

Comments
 (0)