-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Add case_when method #56059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Add case_when method #56059
Changes from all commits
f48502f
40057c7
4a8be16
089bbe6
bcfd458
b95ce55
acc3fdb
8be4349
8d08458
ec18086
0b72fbb
0085956
a441481
29ad697
bf740f9
264a675
2a3035e
5e33304
5c7c287
8569cd1
bbb5887
e03e3dc
283488f
7a8694c
f6cf725
67dfcaa
3da7cf2
bdc54f6
649fb84
b68d20e
b4de208
5966bfe
3e404fa
21659bc
f6d8cd0
becc626
918a19e
5744df2
bc6ba0e
a0f4797
cb7d6e3
c8f0e2e
9679b9e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,6 +67,9 @@ | |
from pandas.core.dtypes.astype import astype_is_view | ||
from pandas.core.dtypes.cast import ( | ||
LossySetitemError, | ||
construct_1d_arraylike_from_scalar, | ||
find_common_type, | ||
infer_dtype_from, | ||
maybe_box_native, | ||
maybe_cast_pointwise_result, | ||
) | ||
|
@@ -84,7 +87,10 @@ | |
CategoricalDtype, | ||
ExtensionDtype, | ||
) | ||
from pandas.core.dtypes.generic import ABCDataFrame | ||
from pandas.core.dtypes.generic import ( | ||
ABCDataFrame, | ||
ABCSeries, | ||
) | ||
from pandas.core.dtypes.inference import is_hashable | ||
from pandas.core.dtypes.missing import ( | ||
isna, | ||
|
@@ -113,6 +119,7 @@ | |
from pandas.core.arrays.sparse import SparseAccessor | ||
from pandas.core.arrays.string_ import StringDtype | ||
from pandas.core.construction import ( | ||
array as pd_array, | ||
extract_array, | ||
sanitize_array, | ||
) | ||
|
@@ -5629,6 +5636,121 @@ def between( | |
|
||
return lmask & rmask | ||
|
||
def case_when( | ||
self, | ||
caselist: list[ | ||
tuple[ | ||
ArrayLike | Callable[[Series], Series | np.ndarray | Sequence[bool]], | ||
ArrayLike | Scalar | Callable[[Series], Series | np.ndarray], | ||
], | ||
], | ||
) -> Series: | ||
""" | ||
Replace values where the conditions are True. | ||
|
||
Parameters | ||
---------- | ||
caselist : A list of tuples of conditions and expected replacements | ||
Takes the form: ``(condition0, replacement0)``, | ||
``(condition1, replacement1)``, ... . | ||
``condition`` should be a 1-D boolean array-like object | ||
or a callable. If ``condition`` is a callable, | ||
it is computed on the Series | ||
and should return a boolean Series or array. | ||
samukweku marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The callable must not change the input Series | ||
(though pandas doesn`t check it). ``replacement`` should be a | ||
1-D array-like object, a scalar or a callable. | ||
If ``replacement`` is a callable, it is computed on the Series | ||
and should return a scalar or Series. The callable | ||
must not change the input Series | ||
(though pandas doesn`t check it). | ||
|
||
.. versionadded:: 2.2.0 | ||
|
||
Returns | ||
------- | ||
Series | ||
|
||
See Also | ||
-------- | ||
Series.mask : Replace values where the condition is True. | ||
|
||
Examples | ||
-------- | ||
>>> c = pd.Series([6, 7, 8, 9], name='c') | ||
>>> a = pd.Series([0, 0, 1, 2]) | ||
>>> b = pd.Series([0, 3, 4, 5]) | ||
|
||
>>> c.case_when(caselist=[(a.gt(0), a), # condition, replacement | ||
... (b.gt(0), b)]) | ||
0 6 | ||
1 3 | ||
2 1 | ||
3 2 | ||
Name: c, dtype: int64 | ||
""" | ||
if not isinstance(caselist, list): | ||
raise TypeError( | ||
f"The caselist argument should be a list; instead got {type(caselist)}" | ||
) | ||
|
||
if not caselist: | ||
raise ValueError( | ||
"provide at least one boolean condition, " | ||
"with a corresponding replacement." | ||
) | ||
|
||
for num, entry in enumerate(caselist): | ||
if not isinstance(entry, tuple): | ||
raise TypeError( | ||
f"Argument {num} must be a tuple; instead got {type(entry)}." | ||
) | ||
if len(entry) != 2: | ||
raise ValueError( | ||
f"Argument {num} must have length 2; " | ||
"a condition and replacement; " | ||
f"instead got length {len(entry)}." | ||
) | ||
caselist = [ | ||
( | ||
com.apply_if_callable(condition, self), | ||
com.apply_if_callable(replacement, self), | ||
) | ||
for condition, replacement in caselist | ||
] | ||
default = self.copy() | ||
conditions, replacements = zip(*caselist) | ||
common_dtypes = [infer_dtype_from(arg)[0] for arg in [*replacements, default]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this be a set in the first place? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if len(set(common_dtypes)) > 1: | ||
common_dtype = find_common_type(common_dtypes) | ||
updated_replacements = [] | ||
for condition, replacement in zip(conditions, replacements): | ||
if is_scalar(replacement): | ||
replacement = construct_1d_arraylike_from_scalar( | ||
value=replacement, length=len(condition), dtype=common_dtype | ||
) | ||
elif isinstance(replacement, ABCSeries): | ||
replacement = replacement.astype(common_dtype) | ||
else: | ||
replacement = pd_array(replacement, dtype=common_dtype) | ||
updated_replacements.append(replacement) | ||
replacements = updated_replacements | ||
default = default.astype(common_dtype) | ||
|
||
counter = reversed(range(len(conditions))) | ||
for position, condition, replacement in zip( | ||
counter, conditions[::-1], replacements[::-1] | ||
): | ||
try: | ||
default = default.mask( | ||
condition, other=replacement, axis=0, inplace=False, level=None | ||
) | ||
except Exception as error: | ||
raise ValueError( | ||
f"Failed to apply condition{position} and replacement{position}." | ||
) from error | ||
return default | ||
|
||
# error: Cannot determine type of 'isna' | ||
@doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type] | ||
def isna(self) -> Series: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pandas import ( | ||
DataFrame, | ||
Series, | ||
array as pd_array, | ||
date_range, | ||
) | ||
import pandas._testing as tm | ||
|
||
|
||
@pytest.fixture
def df():
    """
    Small two-column integer frame shared by the tests in this module.
    """
    columns = {"a": [1, 2, 3], "b": [4, 5, 6]}
    return DataFrame(columns)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you in-line this dataframe where used? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is used in a number of places (about 7); inlining it feels repetitive. Is there a reason to change to inlining? |
||
|
||
|
||
def test_case_when_caselist_is_not_a_list(df):
    """
    Raise TypeError if caselist is not a list.
    """
    # A tuple is passed instead of a list; the trailing ``.+`` stands in for
    # the type name that follows "instead got" in the message.
    msg = "The caselist argument should be a list; "
    msg += "instead got.+"
    with pytest.raises(TypeError, match=msg):  # GH39154
        df["a"].case_when(caselist=())
|
||
|
||
def test_case_when_no_caselist(df):
    """
    An empty caselist is rejected with a ValueError.
    """
    expected_message = (
        "provide at least one boolean condition, "
        "with a corresponding replacement."
    )
    column = df["a"]
    with pytest.raises(ValueError, match=expected_message):  # GH39154
        column.case_when([])
|
||
|
||
def test_case_when_odd_caselist(df):
    """
    Raise ValueError if a caselist entry does not have exactly two items.
    """
    msg = "Argument 0 must have length 2; "
    msg += "a condition and replacement; instead got length 3."

    # The single entry carries three items instead of a
    # (condition, replacement) pair.
    with pytest.raises(ValueError, match=msg):
        df["a"].case_when([(df["a"].eq(1), 1, df["a"].gt(1))])
|
||
|
||
def test_case_when_raise_error_from_mask(df):
    """
    A failure inside Series.mask is surfaced as a ValueError.

    The replacement has two items while the Series has three.
    """
    expected_message = "Failed to apply condition0 and replacement0."
    bad_caselist = [(df["a"].eq(1), [1, 2])]
    with pytest.raises(ValueError, match=expected_message):
        df["a"].case_when(bad_caselist)
|
||
|
||
def test_case_when_single_condition(df):
    """
    Check the result for a caselist holding one (condition, replacement) pair.
    """
    all_missing = Series([np.nan] * 3)
    caselist = [(df.a.eq(1), 1)]
    result = all_missing.case_when(caselist)
    expected = Series([1, np.nan, np.nan])
    tm.assert_series_equal(result, expected)
|
||
|
||
def test_case_when_multiple_conditions(df):
    """
    Check the result for two conditions, one of them computed from the frame.
    """
    computed_condition = df.a.eq(1)
    literal_condition = Series([False, True, False])
    all_missing = Series([np.nan] * 3)
    result = all_missing.case_when(
        [(computed_condition, 1), (literal_condition, 2)]
    )
    expected = Series([1, 2, np.nan])
    tm.assert_series_equal(result, expected)
|
||
|
||
def test_case_when_multiple_conditions_replacement_list(df):
    """
    Check the result when one of the replacements is a plain list.
    """
    second_condition = df["a"].gt(1) & df["b"].eq(5)
    caselist = [([True, False, False], 1), (second_condition, [1, 2, 3])]
    result = Series([np.nan] * 3).case_when(caselist)
    expected = Series([1, 2, np.nan])
    tm.assert_series_equal(result, expected)
|
||
|
||
def test_case_when_multiple_conditions_replacement_extension_dtype(df):
    """
    Check the result when one of the replacements is an ``Int64`` array.
    """
    second_condition = df["a"].gt(1) & df["b"].eq(5)
    int64_replacement = pd_array([1, 2, 3], dtype="Int64")
    caselist = [
        ([True, False, False], 1),
        (second_condition, int64_replacement),
    ]
    result = Series([np.nan] * 3).case_when(caselist)
    expected = Series([1, 2, np.nan], dtype="Float64")
    tm.assert_series_equal(result, expected)
|
||
|
||
def test_case_when_multiple_conditions_replacement_series(df):
    """
    Check the result when one of the replacements is itself a Series.
    """
    first_condition = np.array([True, False, False])
    second_condition = df["a"].gt(1) & df["b"].eq(5)
    caselist = [
        (first_condition, 1),
        (second_condition, Series([1, 2, 3])),
    ]
    result = Series([np.nan] * 3).case_when(caselist)
    expected = Series([1, 2, np.nan])
    tm.assert_series_equal(result, expected)
|
||
|
||
def test_case_when_non_range_index(): | ||
""" | ||
Test output if index is not RangeIndex | ||
""" | ||
rng = np.random.default_rng(seed=123) | ||
dates = date_range("1/1/2000", periods=8) | ||
df = DataFrame( | ||
rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"] | ||
) | ||
result = Series(5, index=df.index, name="A").case_when([(df.A.gt(0), df.B)]) | ||
expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5) | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
def test_case_when_callable(): | ||
""" | ||
Test output on a callable | ||
""" | ||
# https://numpy.org/doc/stable/reference/generated/numpy.piecewise.html | ||
x = np.linspace(-2.5, 2.5, 6) | ||
ser = Series(x) | ||
result = ser.case_when( | ||
caselist=[ | ||
(lambda df: df < 0, lambda df: -df), | ||
(lambda df: df >= 0, lambda df: df), | ||
] | ||
) | ||
expected = np.piecewise(x, [x < 0, x >= 0], [lambda x: -x, lambda x: x]) | ||
tm.assert_series_equal(result, Series(expected)) |
Uh oh!
There was an error while loading. Please reload this page.