diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index caf237fb15163..bb25a4b3ad17e 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -26,6 +26,7 @@ Bug fixes - Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`) - Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`) - Bug in :func:`pandas.testing.assert_series_equal` where ``check_dtype=False`` would still raise for datetime or timedelta types with different resolutions (:issue:`52449`) +- Bug in :meth:`ArrowDtype.__from_arrow__` not respecting if dtype is explicitly given (:issue:`52533`) .. --------------------------------------------------------------------------- .. _whatsnew_201.other: diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index 2c719c0405e14..c416fbd03417a 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -319,4 +319,5 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray): Construct IntegerArray/FloatingArray from pyarrow Array/ChunkedArray. """ array_class = self.construct_array_type() - return array_class(array) + arr = array.cast(self.pyarrow_dtype, safe=True) + return array_class(arr) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 99f2a7d820dcb..8c1f914709a81 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1739,6 +1739,28 @@ def test_setitem_invalid_dtype(data): data[:] = fill_value +@pytest.mark.skipif(pa_version_under8p0, reason="returns object with 7.0") +def test_from_arrow_respecting_given_dtype(): + date_array = pa.array( + [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")], type=pa.date32() + ) + result = date_array.to_pandas( + types_mapper={pa.date32(): ArrowDtype(pa.date64())}.get + ) + expected = pd.Series( + [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")], + dtype=ArrowDtype(pa.date64()), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.skipif(pa_version_under8p0, reason="doesn't raise with 7") +def test_from_arrow_respecting_given_dtype_unsafe(): + array = pa.array([1.5, 2.5], type=pa.float64()) + with pytest.raises(pa.ArrowInvalid, match="Float value 1.5 was truncated"): + array.to_pandas(types_mapper={pa.float64(): ArrowDtype(pa.int64())}.get) + + def test_round(): dtype = "float64[pyarrow]"