diff --git a/doc/source/whatsnew/v0.21.0.txt b/doc/source/whatsnew/v0.21.0.txt index 7fbf2533428dc..5f0af8859a133 100644 --- a/doc/source/whatsnew/v0.21.0.txt +++ b/doc/source/whatsnew/v0.21.0.txt @@ -768,6 +768,7 @@ Conversion - Bug in :attr:`Timestamp.weekday_name` returning a UTC-based weekday name when localized to a timezone (:issue:`17354`) - Bug in ``Timestamp.replace`` when replacing ``tzinfo`` around DST changes (:issue:`15683`) - Bug in ``Timedelta`` construction and arithmetic that would not propagate the ``Overflow`` exception (:issue:`17367`) +- Bug in :meth:`~DataFrame.astype` converting to object dtype when passed extension type classes (``DatetimeTZDtype``, ``CategoricalDtype``) rather than instances. Now a ``TypeError`` is raised when a class is passed (:issue:`17780`). Indexing ^^^^^^^^ diff --git a/pandas/core/internals.py b/pandas/core/internals.py index a8f1a0c78c238..689f5521e1ccb 100644 --- a/pandas/core/internals.py +++ b/pandas/core/internals.py @@ -1,6 +1,7 @@ import warnings import copy from warnings import catch_warnings +import inspect import itertools import re import operator @@ -552,6 +553,11 @@ def _astype(self, dtype, copy=False, errors='raise', values=None, list(errors_legal_values), errors)) raise ValueError(invalid_arg) + if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype): + msg = ("Expected an instance of {}, but got the class instead. 
" + "Try instantiating 'dtype'.".format(dtype.__name__)) + raise TypeError(msg) + # may need to convert to categorical # this is only called for non-categoricals if self.is_categorical_astype(dtype): diff --git a/pandas/tests/frame/test_dtypes.py b/pandas/tests/frame/test_dtypes.py index 5941b2ab7c2cb..abb528f0d2179 100644 --- a/pandas/tests/frame/test_dtypes.py +++ b/pandas/tests/frame/test_dtypes.py @@ -612,6 +612,20 @@ def test_astype_duplicate_col(self): expected = concat([a1_str, b, a2_str], axis=1) assert_frame_equal(result, expected) + @pytest.mark.parametrize("cls", [ + pd.api.types.CategoricalDtype, + pd.api.types.DatetimeTZDtype, + pd.api.types.IntervalDtype + ]) + def test_astype_categoricaldtype_class_raises(self, cls): + df = DataFrame({"A": ['a', 'a', 'b', 'c']}) + xpr = "Expected an instance of {}".format(cls.__name__) + with tm.assert_raises_regex(TypeError, xpr): + df.astype({"A": cls}) + + with tm.assert_raises_regex(TypeError, xpr): + df['A'].astype(cls) + def test_timedeltas(self): df = DataFrame(dict(A=Series(date_range('2012-1-1', periods=3, freq='D')),