diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 16cfc3efc1024..9daa83cf6b85f 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -14,6 +14,25 @@ including other versions of pandas. Enhancements ~~~~~~~~~~~~ +.. _whatsnew_220.enhancements.case_when: + +Create a pandas Series based on one or more conditions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`) + +.. ipython:: python + + import pandas as pd + + df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) + pd.case_when( + (df.a == 1, 'first'), # condition, replacement + (df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement + default = 'default', # optional + ) + + .. _whatsnew_220.enhancements.calamine: Calamine engine for :func:`read_excel` diff --git a/pandas/__init__.py b/pandas/__init__.py index 7fab662ed2de4..232acf3aa85fa 100644 --- a/pandas/__init__.py +++ b/pandas/__init__.py @@ -73,6 +73,7 @@ notnull, # indexes Index, + case_when, CategoricalIndex, RangeIndex, MultiIndex, @@ -252,6 +253,7 @@ __all__ = [ "ArrowDtype", "BooleanDtype", + "case_when", "Categorical", "CategoricalDtype", "CategoricalIndex", diff --git a/pandas/core/api.py b/pandas/core/api.py index 2cfe5ffc0170d..5c7271c155459 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -42,6 +42,7 @@ UInt64Dtype, ) from pandas.core.arrays.string_ import StringDtype +from pandas.core.case_when import case_when from pandas.core.construction import array from pandas.core.flags import Flags from pandas.core.groupby import ( @@ -86,6 +87,7 @@ "bdate_range", "BooleanDtype", "Categorical", + "case_when", "CategoricalDtype", "CategoricalIndex", "DataFrame", diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py new file mode 100644 index 0000000000000..7a08e6397322c --- /dev/null +++ b/pandas/core/case_when.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from pandas._libs import lib + +from pandas.core.dtypes.cast import ( + construct_1d_arraylike_from_scalar, + find_common_type, + infer_dtype_from, +) +from pandas.core.dtypes.common import is_scalar +from pandas.core.dtypes.generic import ABCSeries + +from pandas.core.construction import array as pd_array + +if TYPE_CHECKING: + from pandas._typing import Series + + +def case_when( + *args: tuple[tuple], + default=lib.no_default, +) -> Series: + """ + Replace values where the conditions are True. + + Parameters + ---------- + *args : tuple(s) of array-like, scalar + Variable argument of tuples of conditions and expected replacements. + Takes the form: ``(condition0, replacement0)``, + ``(condition1, replacement1)``, ... . + ``condition`` should be a 1-D boolean array. + When multiple boolean conditions are satisfied, + the first replacement is used. + If ``condition`` is a Series, and the equivalent ``replacement`` + is a Series, they must have the same index. + If there are multiple replacement options, + and they are Series, they must have the same index. + + default : scalar, array-like, default None + If provided, it is the replacement value to use + if all conditions evaluate to False. + If not specified, entries will be filled with the + corresponding NULL value. + + .. versionadded:: 2.2.0 + + Returns + ------- + Series + + See Also + -------- + Series.mask : Replace values where the condition is True. + + Examples + -------- + >>> df = pd.DataFrame({ + ... "a": [0,0,1,2], + ... "b": [0,3,4,5], + ... "c": [6,7,8,9] + ... }) + >>> df + a b c + 0 0 0 6 + 1 0 3 7 + 2 1 4 8 + 3 2 5 9 + + >>> pd.case_when((df.a.gt(0), df.a), # condition, replacement + ... (df.b.gt(0), df.b), # condition, replacement + ... default=df.c) # optional + 0 6 + 1 3 + 2 1 + 3 2 + Name: c, dtype: int64 + """ + from pandas import Series + from pandas._testing.asserters import assert_index_equal + + args = validate_case_when(args=args) + + conditions, replacements = zip(*args) + common_dtypes = [infer_dtype_from(replacement)[0] for replacement in replacements] + + if default is not lib.no_default: + arg_dtype, _ = infer_dtype_from(default) + common_dtypes.append(arg_dtype) + else: + default = None + if len(set(common_dtypes)) > 1: + common_dtypes = find_common_type(common_dtypes) + updated_replacements = [] + for condition, replacement in zip(conditions, replacements): + if is_scalar(replacement): + replacement = construct_1d_arraylike_from_scalar( + value=replacement, length=len(condition), dtype=common_dtypes + ) + elif isinstance(replacement, ABCSeries): + replacement = replacement.astype(common_dtypes) + else: + replacement = pd_array(replacement, dtype=common_dtypes) + updated_replacements.append(replacement) + replacements = updated_replacements + if (default is not None) and isinstance(default, ABCSeries): + default = default.astype(common_dtypes) + else: + common_dtypes = common_dtypes[0] + if not isinstance(default, ABCSeries): + cond_indices = [cond for cond in conditions if isinstance(cond, ABCSeries)] + replacement_indices = [ + replacement + for replacement in replacements + if isinstance(replacement, ABCSeries) + ] + cond_length = None + if replacement_indices: + for left, right in zip(replacement_indices, replacement_indices[1:]): + try: + assert_index_equal(left.index, right.index, check_order=False) + except AssertionError: + raise AssertionError( + "All replacement objects must have the same index." + ) + if cond_indices: + for left, right in zip(cond_indices, cond_indices[1:]): + try: + assert_index_equal(left.index, right.index, check_order=False) + except AssertionError: + raise AssertionError( + "All condition objects must have the same index." + ) + if replacement_indices: + try: + assert_index_equal( + replacement_indices[0].index, + cond_indices[0].index, + check_order=False, + ) + except AssertionError: + raise ValueError( + "All replacement objects and condition objects " + "should have the same index." + ) + else: + conditions = [ + np.asanyarray(cond) if not hasattr(cond, "shape") else cond + for cond in conditions + ] + cond_length = {len(cond) for cond in conditions} + if len(cond_length) > 1: + raise ValueError("The boolean conditions should have the same length.") + cond_length = len(conditions[0]) + if not is_scalar(default): + if len(default) != cond_length: + raise ValueError( + "length of `default` does not match the length " + "of any of the conditions." + ) + if not replacement_indices: + for num, replacement in enumerate(replacements): + if is_scalar(replacement): + continue + if not hasattr(replacement, "shape"): + replacement = np.asanyarray(replacement) + if len(replacement) != cond_length: + raise ValueError( + f"Length of condition{num} does not match " + f"the length of replacement{num}; " + f"{cond_length} != {len(replacement)}" + ) + if cond_indices: + default_index = cond_indices[0].index + elif replacement_indices: + default_index = replacement_indices[0].index + else: + default_index = range(cond_length) + default = Series(default, index=default_index, dtype=common_dtypes) + counter = reversed(range(len(conditions))) + for position, condition, replacement in zip( + counter, conditions[::-1], replacements[::-1] + ): + try: + default = default.mask( + condition, other=replacement, axis=0, inplace=False, level=None + ) + except Exception as error: + raise ValueError( + f"condition{position} and replacement{position} failed to evaluate." + ) from error + return default + + +def validate_case_when(args: tuple) -> tuple: + """ + Validates the variable arguments for the case_when function. + """ + if not len(args): + raise ValueError( + "provide at least one boolean condition, " + "with a corresponding replacement." + ) + + for num, entry in enumerate(args): + if not isinstance(entry, tuple): + raise TypeError(f"Argument {num} must be a tuple; got {type(entry)}.") + if len(entry) != 2: + raise ValueError( + f"Argument {num} must have length 2; " + "a condition and replacement; " + f"got length {len(entry)}." + ) + return args diff --git a/pandas/core/series.py b/pandas/core/series.py index 1bbd10429ea22..849494767bbaf 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -106,6 +106,10 @@ from pandas.core.arrays.categorical import CategoricalAccessor from pandas.core.arrays.sparse import SparseAccessor from pandas.core.arrays.string_ import StringDtype +from pandas.core.case_when import ( + case_when, + validate_case_when, +) from pandas.core.construction import ( extract_array, sanitize_array, @@ -5554,6 +5558,80 @@ def between( return lmask & rmask + def case_when( + self, + *args: tuple[tuple], + ) -> Series: + """ + Replace values where the conditions are True. + + Parameters + ---------- + *args : tuple(s) of array-like, scalar. + Variable argument of tuples of conditions and expected replacements. + Takes the form: ``(condition0, replacement0)``, + ``(condition1, replacement1)``, ... . + ``condition`` should be a 1-D boolean array-like object + or a callable. If ``condition`` is a callable, + it is computed on the Series + and should return a boolean Series or array. + The callable must not change the input Series + (though pandas doesn`t check it). ``replacement`` should be a + 1-D array-like object, a scalar or a callable. + If ``replacement`` is a callable, it is computed on the Series + and should return a scalar or Series. The callable + must not change the input Series + (though pandas doesn`t check it). + If ``condition`` is a Series, and the equivalent ``replacement`` + is a Series, they must have the same index. + If there are multiple replacement options, + and they are Series, they must have the same index. + + level : int, default None + Alignment level if needed. + + .. versionadded:: 2.2.0 + + Returns + ------- + Series + + See Also + -------- + Series.mask : Replace values where the condition is True. + + Examples + -------- + >>> df = pd.DataFrame({ + ... "a": [0,0,1,2], + ... "b": [0,3,4,5], + ... "c": [6,7,8,9] + ... }) + >>> df + a b c + 0 0 0 6 + 1 0 3 7 + 2 1 4 8 + 3 2 5 9 + + >>> df.c.case_when((df.a.gt(0), df.a), # condition, replacement + ... (df.b.gt(0), df.b)) + 0 6 + 1 3 + 2 1 + 3 2 + Name: c, dtype: int64 + """ + args = validate_case_when(args) + args = [ + ( + com.apply_if_callable(condition, self), + com.apply_if_callable(replacement, self), + ) + for condition, replacement in args + ] + return case_when(*args, default=self, level=None) + # error: Cannot determine type of 'isna' @doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type] def isna(self) -> Series: diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py index 60bcb97aaa364..ae4fe5d56ebf6 100644 --- a/pandas/tests/api/test_api.py +++ b/pandas/tests/api/test_api.py @@ -106,6 +106,7 @@ class TestPDApi(Base): funcs = [ "array", "bdate_range", + "case_when", "concat", "crosstab", "cut", diff --git a/pandas/tests/test_case_when.py b/pandas/tests/test_case_when.py new file mode 100644 index 0000000000000..0b231b86ac58d --- /dev/null +++ b/pandas/tests/test_case_when.py @@ -0,0 +1,168 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + array as pd_array, + case_when, + date_range, +) +import pandas._testing as tm + + +@pytest.fixture +def df(): + """ + base dataframe for testing + """ + return DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +def test_case_when_no_args(): + """ + Raise ValueError if no args is provided. + """ + msg = "provide at least one boolean condition, " + msg += "with a corresponding replacement." + with pytest.raises(ValueError, match=msg): + case_when() + + +def test_case_when_odd_args(df): + """ + Raise ValueError if no of args is odd. + """ + msg = "Argument 0 must have length 2; " + msg += "a condition and replacement; got length 3." + + with pytest.raises(ValueError, match=msg): + case_when((df["a"].eq(1), 1, df.a.gt(1))) + + +def test_case_when_raise_error_from_mask(df): + """ + Raise Error from within Series.mask + """ + msg = "condition0 and replacement0 failed to evaluate." + with pytest.raises(ValueError, match=msg): + case_when((df["a"].eq(1), df)) + + +def test_case_when_error_multiple_replacements_series(df): + """ + Test output when the replacements indices are different. + """ + with pytest.raises( + AssertionError, match="All replacement objects must have the same index." + ): + case_when( + ([True, False, False], Series(1)), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ) + + +def test_case_when_error_multiple_conditions_series(df): + """ + Test output when the conditions indices are different. + """ + with pytest.raises( + AssertionError, match="All condition objects must have the same index." + ): + case_when( + (Series([True, False, False], index=[2, 3, 4]), 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ) + + +def test_case_when_raise_error_different_index_condition_and_replacements(df): + """ + Raise if the replacement index and condition index are different. + """ + msg = "All replacement objects and condition objects " + msg += "should have the same index." + with pytest.raises(AssertionError, match=msg): + case_when( + (df.a.eq(1), 1), (Series([False, True, False], index=["a", "b", "c"]), 2) + ) + + +def test_case_when_single_condition(df): + """ + Test output on a single condition. + """ + result = case_when((df.a.eq(1), 1)) + expected = Series([1, np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions(df): + """ + Test output when booleans are derived from a computation + """ + result = case_when((df.a.eq(1), 1), (Series([False, True, False]), 2)) + expected = Series([1, 2, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions_replacement_list(df): + """ + Test output when replacement is a list + """ + result = case_when( + ([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3]) + ) + expected = Series([1, 2, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions_replacement_extension_dtype(df): + """ + Test output when replacement has an extension dtype + """ + result = case_when( + ([True, False, False], 1), + (df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")), + ) + expected = Series([1, 2, np.nan], dtype="Int64") + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions_replacement_series(df): + """ + Test output when replacement is a Series + """ + result = case_when( + (np.array([True, False, False]), 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ) + expected = Series([1, 2, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions_default_is_not_none(df): + """ + Test output when default is not None + """ + result = case_when( + ([True, False, False], 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + default=-1, + ) + expected = Series([1, 2, -1]) + tm.assert_series_equal(result, expected) + + +def test_case_when_non_range_index(): + """ + Test output if index is not RangeIndex + """ + rng = np.random.default_rng(seed=123) + dates = date_range("1/1/2000", periods=8) + df = DataFrame( + rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"] + ) + result = case_when((df.A.gt(0), df.B), default=5) + result = Series(result, name="A") + expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5) + tm.assert_series_equal(result, expected)