diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 905073645fcb3..4a0bf67f47bae 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -9,7 +9,8 @@ from pandas.core.dtypes.dtypes import ( registry, CategoricalDtype, CategoricalDtypeType, DatetimeTZDtype, DatetimeTZDtypeType, PeriodDtype, PeriodDtypeType, IntervalDtype, - IntervalDtypeType, ExtensionDtype) + IntervalDtypeType, PandasExtensionDtype, ExtensionDtype, + _pandas_registry) from pandas.core.dtypes.generic import ( ABCCategorical, ABCPeriodIndex, ABCDatetimeIndex, ABCSeries, ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex, ABCIndexClass, @@ -1709,17 +1710,9 @@ def is_extension_array_dtype(arr_or_dtype): Third-party libraries may implement arrays or types satisfying this interface as well. """ - from pandas.core.arrays import ExtensionArray - - if isinstance(arr_or_dtype, (ABCIndexClass, ABCSeries)): - arr_or_dtype = arr_or_dtype._values - - try: - arr_or_dtype = pandas_dtype(arr_or_dtype) - except TypeError: - pass - - return isinstance(arr_or_dtype, (ExtensionDtype, ExtensionArray)) + dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype) + return (isinstance(dtype, ExtensionDtype) or + registry.find(dtype) is not None) def is_complex_dtype(arr_or_dtype): @@ -1999,12 +1992,12 @@ def pandas_dtype(dtype): return dtype # registered extension types - result = registry.find(dtype) + result = _pandas_registry.find(dtype) or registry.find(dtype) if result is not None: return result # un-registered extension types - elif isinstance(dtype, ExtensionDtype): + elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)): return dtype # try a numpy dtype diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index cf771a127a696..f53ccc86fc4ff 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -22,9 +22,9 @@ class Registry(object): -------- registry.register(MyExtensionDtype) """ - dtypes = [] + def __init__(self): + self.dtypes = [] - @classmethod def register(self, dtype): """ Parameters @@ -50,7 +50,7 @@ def find(self, dtype): dtype_type = dtype if not isinstance(dtype, type): dtype_type = type(dtype) - if issubclass(dtype_type, (PandasExtensionDtype, ExtensionDtype)): + if issubclass(dtype_type, ExtensionDtype): return dtype return None @@ -65,6 +65,9 @@ def find(self, dtype): registry = Registry() +# TODO(Extension): remove the second registry once all internal extension +# dtypes are real extension dtypes. +_pandas_registry = Registry() class PandasExtensionDtype(_DtypeOpsMixin): @@ -822,7 +825,7 @@ def is_dtype(cls, dtype): # register the dtypes in search order -registry.register(DatetimeTZDtype) -registry.register(PeriodDtype) registry.register(IntervalDtype) registry.register(CategoricalDtype) +_pandas_registry.register(DatetimeTZDtype) +_pandas_registry.register(PeriodDtype) diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index 02ac7fc7d5ed7..55c841ba1fc46 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -9,7 +9,7 @@ from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, PeriodDtype, - IntervalDtype, CategoricalDtype, registry) + IntervalDtype, CategoricalDtype, registry, _pandas_registry) from pandas.core.dtypes.common import ( is_categorical_dtype, is_categorical, is_datetime64tz_dtype, is_datetimetz, @@ -775,21 +775,31 @@ def test_update_dtype_errors(self, bad_dtype): @pytest.mark.parametrize( 'dtype', - [DatetimeTZDtype, CategoricalDtype, - PeriodDtype, IntervalDtype]) + [CategoricalDtype, IntervalDtype]) def test_registry(dtype): assert dtype in registry.dtypes +@pytest.mark.parametrize('dtype', [DatetimeTZDtype, PeriodDtype]) +def test_pandas_registry(dtype): + assert dtype not in registry.dtypes + assert dtype in _pandas_registry.dtypes + + @pytest.mark.parametrize( 'dtype, expected', [('int64', None), ('interval', IntervalDtype()), ('interval[int64]', IntervalDtype()), ('interval[datetime64[ns]]', IntervalDtype('datetime64[ns]')), - ('category', CategoricalDtype()), - ('period[D]', PeriodDtype('D')), - ('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))]) + ('category', CategoricalDtype())]) def test_registry_find(dtype, expected): - assert registry.find(dtype) == expected + + +@pytest.mark.parametrize( + 'dtype, expected', + [('period[D]', PeriodDtype('D')), + ('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))]) +def test_pandas_registry_find(dtype, expected): + assert _pandas_registry.find(dtype) == expected