From be34e24aa3aaaf23b9fee881b3d0f86aa23a8891 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Sun, 9 Apr 2023 00:27:17 +0200 Subject: [PATCH 1/6] BUG: __from_arrow__ not respecting explicit dtype --- doc/source/whatsnew/v2.0.1.rst | 1 + pandas/core/arrays/arrow/dtype.py | 4 +++- pandas/tests/extension/test_arrow.py | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index caf237fb15163..bb25a4b3ad17e 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -26,6 +26,7 @@ Bug fixes - Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`) - Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`) - Bug in :func:`pandas.testing.assert_series_equal` where ``check_dtype=False`` would still raise for datetime or timedelta types with different resolutions (:issue:`52449`) +- Bug in :meth:`ArrowDtype.__from_arrow__` not respecting if dtype is explicitly given (:issue:`52533`) .. --------------------------------------------------------------------------- .. _whatsnew_201.other: diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index 2c719c0405e14..67f274dfde971 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -27,6 +27,7 @@ if not pa_version_under7p0: import pyarrow as pa + import pyarrow.compute as pc if TYPE_CHECKING: from pandas._typing import ( @@ -319,4 +320,5 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray): Construct IntegerArray/FloatingArray from pyarrow Array/ChunkedArray. """ array_class = self.construct_array_type() - return array_class(array) + arr = pc.cast(array, self.pyarrow_dtype, safe=True) + return array_class(arr) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 99f2a7d820dcb..9d5eb8ce56eba 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1739,6 +1739,20 @@ def test_setitem_invalid_dtype(data): data[:] = fill_value +def test_from_arrow_respecting_given_dtype(): + date_array = pa.array( + [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")], type=pa.date32() + ) + result = date_array.to_pandas( + types_mapper={pa.date32(): ArrowDtype(pa.date64())}.get + ) + expected = pd.Series( + [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")], + dtype=ArrowDtype(pa.date64()), + ) + tm.assert_series_equal(result, expected) + + def test_round(): dtype = "float64[pyarrow]" From b790dcc275d999f9130a086002e4d5a6eaac9066 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Sun, 9 Apr 2023 17:50:22 +0200 Subject: [PATCH 2/6] Skip for arrow 7 --- pandas/tests/extension/test_arrow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 9d5eb8ce56eba..e2495f2caf11f 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1739,6 +1739,7 @@ def test_setitem_invalid_dtype(data): data[:] = fill_value +@pytest.mark.skipif(pa_version_under8p0, reason="returns object with 7.0") def test_from_arrow_respecting_given_dtype(): date_array = pa.array( [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")], type=pa.date32() From 803561df0dc9864816b15ca2ac51689163e95fc7 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 10 Apr 2023 21:19:24 +0200 Subject: [PATCH 3/6] Add test --- pandas/core/arrays/arrow/dtype.py | 3 +-- pandas/tests/extension/test_arrow.py | 7 +++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index 67f274dfde971..c416fbd03417a 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -27,7 +27,6 @@ if not pa_version_under7p0: import pyarrow as pa - import pyarrow.compute as pc if TYPE_CHECKING: from pandas._typing import ( @@ -320,5 +319,5 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray): Construct IntegerArray/FloatingArray from pyarrow Array/ChunkedArray. """ array_class = self.construct_array_type() - arr = pc.cast(array, self.pyarrow_dtype, safe=True) + arr = array.cast(self.pyarrow_dtype, safe=True) return array_class(arr) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index e2495f2caf11f..b0191daefb528 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1754,6 +1754,13 @@ def test_from_arrow_respecting_given_dtype(): tm.assert_series_equal(result, expected) +@pytest.mark.skipif(pa_version_under8p0, reason="returns object with 7.0") +def test_from_arrow_respecting_given_dtype_unsafe(): + array = pa.array([1.5, 2.5], type=pa.float64()) + with pytest.raises(pa.ArrowInvalid, match="Float value 1.5 was truncated"): + array.to_pandas(types_mapper={pa.float64(): ArrowDtype(pa.int64())}.get) + + def test_round(): dtype = "float64[pyarrow]" From 66e47e23cfecfd4946ebabd05a5afbc55d44780f Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 10 Apr 2023 21:19:58 +0200 Subject: [PATCH 4/6] Remove skip --- pandas/tests/extension/test_arrow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index b0191daefb528..90da136ae0a6f 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1754,7 +1754,6 @@ def test_from_arrow_respecting_given_dtype(): tm.assert_series_equal(result, expected) -@pytest.mark.skipif(pa_version_under8p0, reason="returns object with 7.0") def test_from_arrow_respecting_given_dtype_unsafe(): array = pa.array([1.5, 2.5], type=pa.float64()) with pytest.raises(pa.ArrowInvalid, match="Float value 1.5 was truncated"): From 84332dea6b552cdc772bf808305cde9562a52781 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 10 Apr 2023 22:45:29 +0200 Subject: [PATCH 5/6] Skip --- pandas/tests/extension/test_arrow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 90da136ae0a6f..b0191daefb528 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1754,6 +1754,7 @@ def test_from_arrow_respecting_given_dtype(): tm.assert_series_equal(result, expected) +@pytest.mark.skipif(pa_version_under8p0, reason="returns object with 7.0") def test_from_arrow_respecting_given_dtype_unsafe(): array = pa.array([1.5, 2.5], type=pa.float64()) with pytest.raises(pa.ArrowInvalid, match="Float value 1.5 was truncated"): From 64ebd8ae27536340ed3d3c44c78626d1fc082e77 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 10 Apr 2023 22:45:41 +0200 Subject: [PATCH 6/6] Skip --- pandas/tests/extension/test_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index b0191daefb528..8c1f914709a81 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1754,7 +1754,7 @@ def test_from_arrow_respecting_given_dtype(): tm.assert_series_equal(result, expected) -@pytest.mark.skipif(pa_version_under8p0, reason="returns object with 7.0") +@pytest.mark.skipif(pa_version_under8p0, reason="doesn't raise with 7") def test_from_arrow_respecting_given_dtype_unsafe(): array = pa.array([1.5, 2.5], type=pa.float64()) with pytest.raises(pa.ArrowInvalid, match="Float value 1.5 was truncated"):