diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 64c7503849de2..1f5c3c88c5ff5 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -200,7 +200,7 @@ Missing MultiIndex ^^^^^^^^^^ -- +- Bug in :meth:`MultiIndex.set_levels` not preserving dtypes for :class:`Categorical` (:issue:`52125`) - I/O diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 580a1901fc2da..abe4a00e0b813 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -73,6 +73,7 @@ ABCDatetimeIndex, ABCTimedeltaIndex, ) +from pandas.core.dtypes.inference import is_array_like from pandas.core.dtypes.missing import ( array_equivalent, isna, @@ -945,7 +946,11 @@ def set_levels( FrozenList([['a', 'b', 'c'], [1, 2, 3, 4]]) """ - if is_list_like(levels) and not isinstance(levels, Index): + if isinstance(levels, Index): + pass + elif is_array_like(levels): + levels = Index(levels) + elif is_list_like(levels): levels = list(levels) level, levels = _require_listlike(level, levels, "Levels") diff --git a/pandas/tests/indexes/multi/test_get_set.py b/pandas/tests/indexes/multi/test_get_set.py index 70350f0df821b..8f5bba7debf2a 100644 --- a/pandas/tests/indexes/multi/test_get_set.py +++ b/pandas/tests/indexes/multi/test_get_set.py @@ -367,3 +367,11 @@ def test_set_levels_pos_args_removal(): with pytest.raises(TypeError, match="positional arguments"): idx.set_codes([[0, 1], [1, 0]], 0) + + +def test_set_levels_categorical_keep_dtype(): + # GH#52125 + midx = MultiIndex.from_arrays([[5, 6]]) + result = midx.set_levels(levels=pd.Categorical([1, 2]), level=0) + expected = MultiIndex.from_arrays([pd.Categorical([1, 2])]) + tm.assert_index_equal(result, expected)