diff --git a/pandas/tests/frame/test_dtypes.py b/pandas/tests/frame/test_dtypes.py index 68844aeeb081e..c29f5e78b033f 100644 --- a/pandas/tests/frame/test_dtypes.py +++ b/pandas/tests/frame/test_dtypes.py @@ -815,6 +815,22 @@ def test_astype_extension_dtypes_duplicate_col(self, dtype): expected = concat([a1.astype(dtype), a2.astype(dtype)], axis=1) tm.assert_frame_equal(result, expected) + @pytest.mark.parametrize("kwargs", [dict(), dict(other=None)]) + def test_df_where_with_category(self, kwargs): + # GH 16979 + df = DataFrame(np.arange(2 * 3).reshape(2, 3), columns=list("ABC")) + mask = np.array([[True, False, True], [False, True, True]]) + + # change type to category + df.A = df.A.astype("category") + df.B = df.B.astype("category") + df.C = df.C.astype("category") + + result = df.A.where(mask[:, 0], **kwargs) + expected = Series(pd.Categorical([0, np.nan], categories=[0, 3]), name="A") + + tm.assert_series_equal(result, expected) + @pytest.mark.parametrize( "dtype", [{100: "float64", 200: "uint64"}, "category", "float64"] )