# Patch summary (pandas issue 35793): make the categorical dtype check
# recognise sparse dtypes whose subtype is categorical.
#
# Files touched by the hunks below:
#   * doc/source/whatsnew/v1.2.0.rst
#       Replaces an empty bullet in the "Categorical" section (under
#       "Bug fixes") with an entry for pandas.api.types.is_categorical_dtype.
#   * pandas/core/dtypes/base.py
#       is_dtype(cls, dtype) gains a branch that returns True when `dtype`
#       has a `subtype` attribute and that subtype is an instance of `cls`.
#   * pandas/core/dtypes/common.py
#       The ExtensionDtype fastpath in is_categorical_dtype now returns True
#       either when the dtype's name is "category" or when
#       is_sparse(arr_or_dtype) holds and arr_or_dtype.subtype.name is
#       "category".
#   * pandas/tests/dtypes/test_common.py
#       Adds test_is_categorical_sparse_categorical, which asserts that
#       is_categorical_dtype is truthy for a Series built with
#       SparseDtype(CategoricalDtype(["a", "b", "c"])) and for its dtype.
#
# NOTE(review): the patch text below has its line breaks collapsed (file
# headers, hunk headers and hunk bodies run together on the same line). It
# will likely need its original newlines restored before a patch tool can
# apply it -- TODO confirm whether this is an extraction artifact.
diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 6923b42d3340b..109dbd5839756 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -236,7 +236,7 @@ Bug fixes Categorical ^^^^^^^^^^^ -- +- Bug in :func:`pandas.api.types.is_categorical_dtype` where sparse categorical dtypes would return ``False`` (:issue:`35793`) - Datetimelike diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 3ae5cabf9c73f..4d7e78a0f46e8 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -287,6 +287,10 @@ def is_dtype(cls, dtype: object) -> bool: return False elif isinstance(dtype, cls): return True + elif hasattr(dtype, "subtype") and isinstance( + dtype.subtype, cls # type: ignore[attr-defined] + ): + return True if isinstance(dtype, str): try: return cls.construct_from_string(dtype) is not None diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 5987fdabf78bb..bc94d4ae422d4 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -558,8 +558,10 @@ def is_categorical_dtype(arr_or_dtype) -> bool: """ if isinstance(arr_or_dtype, ExtensionDtype): # GH#33400 fastpath for dtype object - return arr_or_dtype.name == "category" - + return ("category" == arr_or_dtype.name) or ( + is_sparse(arr_or_dtype) + and arr_or_dtype.subtype.name == "category" # type: ignore[attr-defined] + ) if arr_or_dtype is None: return False return CategoricalDtype.is_dtype(arr_or_dtype) diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py index 2db9a9a403e1c..e36b4ae17ed0e 100644 --- a/pandas/tests/dtypes/test_common.py +++ b/pandas/tests/dtypes/test_common.py @@ -213,6 +213,15 @@ def test_is_categorical_deprecation(): com.is_categorical([1, 2, 3]) +def test_is_categorical_sparse_categorical(): + # https://github.com/pandas-dev/pandas/issues/35793 + s = pd.Series( + ["a", "b", "c"], dtype=pd.SparseDtype(CategoricalDtype(["a", "b", 
"c"])) + ) + assert com.is_categorical_dtype(s) + assert com.is_categorical_dtype(s.dtype) + + def test_is_datetime64_dtype(): assert not com.is_datetime64_dtype(object) assert not com.is_datetime64_dtype([1, 2, 3])