From e19fe33c8867a07cc8302b959bf77e518482cc80 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 9 May 2025 13:58:32 +0200 Subject: [PATCH 01/16] (fix): no `fill_value` on reindex --- properties/test_pandas_roundtrip.py | 18 +++++++++---- xarray/core/dtypes.py | 5 +++- xarray/core/duck_array_ops.py | 17 ++++++++++--- xarray/tests/__init__.py | 10 ++++++++ xarray/tests/test_concat.py | 21 ++++++++-------- xarray/tests/test_dataset.py | 39 +++++++++++++++++++---------- 6 files changed, 77 insertions(+), 33 deletions(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 8fc32e75cbd..5f4fab77b7a 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -134,10 +134,18 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: xr.testing.assert_identical(dataset, roundtripped.to_xarray()) -def test_roundtrip_1d_pandas_extension_array() -> None: - df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])}) - arr = xr.Dataset.from_dataframe(df)["cat"] +@pytest.mark.parametrize( + "extension_Array", + [ + pd.Categorical(["a", "b", "c"]), + pd.array([1, 2, 3], dtype="int64"), + pd.array(["a", "b", "c"], dtype="string"), + ], +) +def test_roundtrip_1d_pandas_extension_array(extension_Array) -> None: + df = pd.DataFrame({"arr": extension_Array}) + arr = xr.Dataset.from_dataframe(df)["arr"] roundtripped = arr.to_pandas() - assert (df["cat"] == roundtripped).all() - assert df["cat"].dtype == roundtripped.dtype + assert (df["arr"] == roundtripped).all() + assert df["arr"].dtype == roundtripped.dtype xr.testing.assert_identical(arr, roundtripped.to_xarray()) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index c959a7f2536..132e6553412 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -4,6 +4,7 @@ from typing import Any import numpy as np +import pandas as pd from pandas.api.types import is_extension_array_dtype from xarray.compat import array_api_compat, npcompat @@ -63,7 +64,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: # N.B. these casting rules should match pandas dtype_: np.typing.DTypeLike fill_value: Any - if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): + if pd.api.types.is_extension_array_dtype(dtype): + return dtype, pd.NA + elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): # for now, we always promote string dtypes to object for consistency with existing behavior # TODO: refactor this once we have a better way to handle numpy vlen-string dtypes dtype_ = object diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 48dc3b7627a..062c2526674 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -273,15 +273,24 @@ def as_shared_dtype(scalars_or_arrays, xp=None): extension_array_types = [ x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) ] - non_nans = [x for x in scalars_or_arrays if not isna(x)] - if len(extension_array_types) == len(non_nans) and all( + non_nans_or_scalar = [ + x for x in scalars_or_arrays if not (isna(x) or np.isscalar(x)) + ] + if len(extension_array_types) == len(non_nans_or_scalar) and all( isinstance(x, type(extension_array_types[0])) for x in extension_array_types ): + extension_array_class = type( + non_nans_or_scalar[0].array + if isinstance(non_nans_or_scalar[0], PandasExtensionArray) + else non_nans_or_scalar[0] + ) return [ x - if not isna(x) + if not (isna(x) or np.isscalar(x)) else PandasExtensionArray( - type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype) + extension_array_class._from_sequence( + [x], dtype=non_nans_or_scalar[0].dtype + ) ) for x in scalars_or_arrays ] diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index b33192393f7..dcdf1f1efeb 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -363,6 +363,16 @@ def create_test_data( ) ), ) + obj["var5"] = ( + "dim1", + pd.array( + rs.integers(1, 10, size=dim_sizes[0]).tolist(), dtype=pd.Int64Dtype() + ), + ) + obj["var6"] = ( + "dim1", + pd.array(list(string.ascii_lowercase[: dim_sizes[0]]), dtype="string"), + ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 49c6490d819..5f6909d032c 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -154,19 +154,20 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) -def test_concat_categorical() -> None: +def test_concat_extension_array() -> None: data1 = create_test_data(use_extension_array=True) data2 = create_test_data(use_extension_array=True) concatenated = concat([data1, data2], dim="dim1") - assert ( - concatenated["var4"] - == type(data2["var4"].variable.data)._concat_same_type( - [ - data1["var4"].variable.data, - data2["var4"].variable.data, - ] - ) - ).all() + for var in ["var4", "var5"]: + assert ( + concatenated[var] + == type(data2[var].variable.data)._concat_same_type( + [ + data1[var].variable.data, + data2[var].variable.data, + ] + ) + ).all() def test_concat_missing_multiple_consecutive_var() -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index bacad96a213..3aeadd7ff1b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -298,6 +298,8 @@ def test_repr(self) -> None: var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a' + var5 (dim1) Int64 72B 5 9 7 2 6 2 8 1 + var6 (dim1) string 64B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' Attributes: foo: bar""".format( data["dim3"].dtype, @@ -1827,25 +1829,36 @@ def test_categorical_index_reindex(self) -> None: actual = ds.reindex(cat=["foo"])["cat"].values assert (actual == np.array(["foo"])).all() - @pytest.mark.parametrize("fill_value", [np.nan, pd.NA]) - def test_extensionarray_negative_reindex(self, fill_value) -> None: - cat = pd.Categorical( - ["foo", "bar", "baz"], - categories=["foo", "bar", "baz", "qux", "quux", "corge"], - ) + @pytest.mark.parametrize("fill_value", [np.nan, pd.NA, None]) + @pytest.mark.parametrize( + "extension_array", + [ + pd.Categorical( + ["foo", "bar", "baz"], + categories=["foo", "bar", "baz", "qux"], + ), + pd.array([1, 2, 3], dtype=pd.Int32Dtype()), + pd.array(["a", "b", "c"], dtype="string"), + ], + ) + def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None: ds = xr.Dataset( - {"cat": ("index", cat)}, + {"arr": ("index", extension_array)}, coords={"index": ("index", np.arange(3))}, ) + kwargs = {} + if fill_value is not None: + kwargs["fill_value"] = fill_value reindexed_cat = cast( pd.api.extensions.ExtensionArray, - ( - ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"] - .to_pandas() - .values - ), + (ds.reindex(index=[-1, 1, 1], **kwargs)["arr"].to_pandas().values), ) - assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype)) # type: ignore[attr-defined] + assert reindexed_cat.equals( + pd.array( + [pd.NA, extension_array[1], extension_array[1]], + dtype=extension_array.dtype, + ) + ) # type: ignore[attr-defined] def test_extension_array_reindex_same(self) -> None: series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype()) From 05f37f6eecc4412ea02eafefa49cb006a0932a3b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 9 May 2025 14:29:08 +0200 Subject: [PATCH 02/16] (chore): add `fillna` tests --- xarray/core/duck_array_ops.py | 19 +++++++++++++------ xarray/tests/test_dataarray.py | 7 +++++++ xarray/tests/test_dataset.py | 7 +++++++ 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 062c2526674..9d7eaf055eb 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -279,21 +279,29 @@ def as_shared_dtype(scalars_or_arrays, xp=None): if len(extension_array_types) == len(non_nans_or_scalar) and all( isinstance(x, type(extension_array_types[0])) for x in extension_array_types ): + # Get the extension array class of the first element, guaranteed to be the same + # as the others thanks to the anove check. extension_array_class = type( non_nans_or_scalar[0].array if isinstance(non_nans_or_scalar[0], PandasExtensionArray) else non_nans_or_scalar[0] ) - return [ + # Cast scalars/nans to extension array class + arrays_with_nan_to_sequence = [ x if not (isna(x) or np.isscalar(x)) - else PandasExtensionArray( - extension_array_class._from_sequence( - [x], dtype=non_nans_or_scalar[0].dtype - ) + else extension_array_class._from_sequence( + [x], dtype=non_nans_or_scalar[0].dtype ) for x in scalars_or_arrays ] + # Wrap the output if necessary + return [ + PandasExtensionArray(x) + if not isinstance(x, PandasExtensionArray) + else x + for x in arrays_with_nan_to_sequence + ] raise ValueError( f"Cannot cast values to shared type, found values: {scalars_or_arrays}" ) @@ -416,7 +424,6 @@ def where(condition, x, y): condition = asarray(condition, dtype=dtype, xp=xp) else: condition = astype(condition, dtype=dtype, xp=xp) - return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e7acdcdd4f3..0fc29dad8a6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3075,6 +3075,13 @@ def test_propagate_attrs(self, func) -> None: with set_options(keep_attrs=True): assert func(da).attrs == da.attrs + def test_fillna_extension_array_int(self) -> None: + srs = pd.Series(index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1])) + da = srs.to_xarray() + filled = da.fillna(0) + assert filled.dtype == pd.Int64Dtype() + assert (filled.values == np.array([0, 1, 1])).all() + def test_fillna(self) -> None: a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") actual = a.fillna(-1) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3aeadd7ff1b..ccd9fa7f727 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5456,6 +5456,13 @@ def test_dropna(self) -> None: with pytest.raises(TypeError, match=r"must specify how or thresh"): ds.dropna("a", how=None) # type: ignore[arg-type] + def test_fillna_extension_array_int(self) -> None: + srs = pd.DataFrame({"data": pd.array([pd.NA, 1, 1])}, index=np.array([1, 2, 3])) + ds = srs.to_xarray() + filled = ds.fillna(0) + assert filled.dtype == pd.Int64Dtype() + assert (filled.values == np.array([0, 1, 1])).all() + def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) From 137c4ab1d3706c9bad1ff3ee988d2872fe7c73c7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 9 May 2025 14:31:16 +0200 Subject: [PATCH 03/16] (fix): arbitrary captialization --- properties/test_pandas_roundtrip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 5f4fab77b7a..ce34f32cf27 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -135,7 +135,7 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: @pytest.mark.parametrize( - "extension_Array", + "extension_array", [ pd.Categorical(["a", "b", "c"]), pd.array([1, 2, 3], dtype="int64"), From 3b499cf9c563bc08347235fe0f4c1b1627a6b711 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 9 May 2025 14:34:43 +0200 Subject: [PATCH 04/16] (chore): add interval array explicitly checked --- properties/test_pandas_roundtrip.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index ce34f32cf27..f7f8e1beac0 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -140,10 +140,13 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: pd.Categorical(["a", "b", "c"]), pd.array([1, 2, 3], dtype="int64"), pd.array(["a", "b", "c"], dtype="string"), + pd.arrays.IntervalArray( + [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] + ), ], ) -def test_roundtrip_1d_pandas_extension_array(extension_Array) -> None: - df = pd.DataFrame({"arr": extension_Array}) +def test_roundtrip_1d_pandas_extension_array(extension_array) -> None: + df = pd.DataFrame({"arr": extension_array}) arr = xr.Dataset.from_dataframe(df)["arr"] roundtripped = arr.to_pandas() assert (df["arr"] == roundtripped).all() From 492078265403a97a40fd80a83d9f13b61c006439 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 9 May 2025 14:41:04 +0200 Subject: [PATCH 05/16] (fix): ds values errors + `dropna` tests --- xarray/tests/test_dataarray.py | 7 +++++++ xarray/tests/test_dataset.py | 11 +++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 0fc29dad8a6..97f1575241d 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3082,6 +3082,13 @@ def test_fillna_extension_array_int(self) -> None: assert filled.dtype == pd.Int64Dtype() assert (filled.values == np.array([0, 1, 1])).all() + def test_dropna_extension_array_int(self) -> None: + srs = pd.Series(index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1])) + da = srs.to_xarray() + filled = da.dropna("index") + assert filled.dtype == pd.Int64Dtype() + assert (filled.values == np.array([1, 1])).all() + def test_fillna(self) -> None: a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") actual = a.fillna(-1) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ccd9fa7f727..2d21680c984 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5460,8 +5460,15 @@ def test_fillna_extension_array_int(self) -> None: srs = pd.DataFrame({"data": pd.array([pd.NA, 1, 1])}, index=np.array([1, 2, 3])) ds = srs.to_xarray() filled = ds.fillna(0) - assert filled.dtype == pd.Int64Dtype() - assert (filled.values == np.array([0, 1, 1])).all() + assert filled["data"].dtype == pd.Int64Dtype() + assert (filled["data"].values == np.array([0, 1, 1])).all() + + def test_dropna_extension_array_int(self) -> None: + srs = pd.DataFrame({"data": pd.array([pd.NA, 1, 1])}, index=np.array([1, 2, 3])) + ds = srs.to_xarray() + dropped = ds.dropna("index") + assert dropped["data"].dtype == pd.Int64Dtype() + assert (dropped["data"].values == np.array([1, 1])).all() def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) From 6c97217298180b40d8ada1330b7e65cd4b1e8bbe Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 9 May 2025 14:51:39 +0200 Subject: [PATCH 06/16] (fix): mypy --- xarray/tests/test_dataarray.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 97f1575241d..979299e59b7 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3076,14 +3076,18 @@ def test_propagate_attrs(self, func) -> None: assert func(da).attrs == da.attrs def test_fillna_extension_array_int(self) -> None: - srs = pd.Series(index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1])) + srs: pd.Series = pd.Series( + index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1]) + ) da = srs.to_xarray() filled = da.fillna(0) assert filled.dtype == pd.Int64Dtype() assert (filled.values == np.array([0, 1, 1])).all() def test_dropna_extension_array_int(self) -> None: - srs = pd.Series(index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1])) + srs: pd.Series = pd.Series( + index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1]) + ) da = srs.to_xarray() filled = da.dropna("index") assert filled.dtype == pd.Int64Dtype() From c561e4e9b617df101d254dd458fefa583e8e20ca Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 14:30:14 +0200 Subject: [PATCH 07/16] (fix): add in `reshape` implemented + remove integer tests --- xarray/core/extension_array.py | 13 ++++++++++++- xarray/core/formatting.py | 7 ++++++- xarray/tests/__init__.py | 6 +----- xarray/tests/test_dataarray.py | 14 +++++++------- xarray/tests/test_dataset.py | 8 +++----- 5 files changed, 29 insertions(+), 19 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 096a427e425..e27e53c7747 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -51,6 +51,17 @@ def __extension_duck_array__concatenate( return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined] +@implements(np.reshape) +def __extension_duck_array__concatenate( + arr: T_ExtensionArray, shape: tuple +) -> T_ExtensionArray: + if (shape[0] == len(arr) and len(shape) == 1) or shape == (-1,): + return arr + raise NotImplementedError( + f"Cannot reshape 1d-only pandas extension array to: {shape}" + ) + + @implements(np.where) def __extension_duck_array__where( condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray @@ -103,7 +114,7 @@ def replace_duck_with_extension_array(args) -> list: return func(*args, **kwargs) res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) if is_extension_array_dtype(res): - return type(self)[type(res)](res) + return PandasExtensionArray(res) return res def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 7aa333ffb2e..86fb147d382 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -19,6 +19,7 @@ from xarray.core.datatree_render import RenderDataTree from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype, ravel +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import MemoryCachedArray from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.treenode import group_subtrees @@ -176,6 +177,8 @@ def format_timedelta(t, timedelta_format=None): def format_item(x, timedelta_format=None, quote_strings=True): """Returns a succinct summary of an object as a string""" + if isinstance(x, PandasExtensionArray): + return f"{x.array[0]}" if isinstance(x, np.datetime64 | datetime): return format_timestamp(x) if isinstance(x, np.timedelta64 | timedelta): @@ -194,7 +197,9 @@ def format_items(x): """Returns a succinct summaries of all items in a sequence as strings""" x = to_duck_array(x) timedelta_format = "datetime" - if np.issubdtype(x.dtype, np.timedelta64): + if not isinstance(x, PandasExtensionArray) and np.issubdtype( + x.dtype, np.timedelta64 + ): x = astype(x, dtype="timedelta64[ns]") day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index dcdf1f1efeb..8e7503cde4b 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -366,13 +366,9 @@ def create_test_data( obj["var5"] = ( "dim1", pd.array( - rs.integers(1, 10, size=dim_sizes[0]).tolist(), dtype=pd.Int64Dtype() + rs.integers(1, 10, size=dim_sizes[0]).tolist(), dtype="int64[pyarrow]" ), ) - obj["var6"] = ( - "dim1", - pd.array(list(string.ascii_lowercase[: dim_sizes[0]]), dtype="string"), - ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 979299e59b7..b0ff72ed9a4 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3075,23 +3075,23 @@ def test_propagate_attrs(self, func) -> None: with set_options(keep_attrs=True): assert func(da).attrs == da.attrs - def test_fillna_extension_array_int(self) -> None: + def test_fillna_extension_array(self) -> None: srs: pd.Series = pd.Series( - index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1]) + index=np.array([1, 2, 3]), data=pd.Categorical([pd.NA, "a", "b"]) ) da = srs.to_xarray() filled = da.fillna(0) - assert filled.dtype == pd.Int64Dtype() + assert filled.dtype == srs.dtype assert (filled.values == np.array([0, 1, 1])).all() - def test_dropna_extension_array_int(self) -> None: + def test_dropna_extension_array(self) -> None: srs: pd.Series = pd.Series( - index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1]) + index=np.array([1, 2, 3]), data=pd.Categorical([pd.NA, "a", "b"]) ) da = srs.to_xarray() filled = da.dropna("index") - assert filled.dtype == pd.Int64Dtype() - assert (filled.values == np.array([1, 1])).all() + assert filled.dtype == srs.dtype + assert (filled.values == np.array(["a", "b"])).all() def test_fillna(self) -> None: a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2d21680c984..136d39b56a3 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -297,9 +297,8 @@ def test_repr(self) -> None: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a' - var5 (dim1) Int64 72B 5 9 7 2 6 2 8 1 - var6 (dim1) string 64B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' + var4 (dim1) category 32B b c b a c a c a + var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1 Attributes: foo: bar""".format( data["dim3"].dtype, @@ -1837,8 +1836,7 @@ def test_categorical_index_reindex(self) -> None: ["foo", "bar", "baz"], categories=["foo", "bar", "baz", "qux"], ), - pd.array([1, 2, 3], dtype=pd.Int32Dtype()), - pd.array(["a", "b", "c"], dtype="string"), + pd.array([1, 1, None], dtype="int64[pyarrow]"), ], ) def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None: From 29131c2524fd7558846698ba0e37587956ac54a9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 16:51:49 +0200 Subject: [PATCH 08/16] (fix): add different test cases for fillna + dropna --- xarray/core/extension_array.py | 4 ++++ xarray/tests/test_dataarray.py | 39 ++++++++++++++++++++++++++-------- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index e27e53c7747..0272c62eaac 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -121,6 +121,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: + if isinstance(key, tuple): + if len(key) > 1: + raise IndexError("Too many indices for array.") + key = key[0] item = self.array[key] if is_extension_array_dtype(item): return PandasExtensionArray(item) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b0ff72ed9a4..033046fad9f 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3075,23 +3075,44 @@ def test_propagate_attrs(self, func) -> None: with set_options(keep_attrs=True): assert func(da).attrs == da.attrs - def test_fillna_extension_array(self) -> None: - srs: pd.Series = pd.Series( - index=np.array([1, 2, 3]), data=pd.Categorical([pd.NA, "a", "b"]) - ) + @pytest.mark.parametrize( + "fill_value,extension_array", + [ + ("a", pd.Categorical([pd.NA, "a", "b"])), + (0, pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]")), + ], + ids=["categorical", "int64[pyarrow]"], + ) + def test_fillna_extension_array(self, fill_value, extension_array) -> None: + srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array) da = srs.to_xarray() - filled = da.fillna(0) + filled = da.fillna(fill_value) assert filled.dtype == srs.dtype - assert (filled.values == np.array([0, 1, 1])).all() + assert (filled.values == np.array([fill_value, *(srs.values[1:])])).all() - def test_dropna_extension_array(self) -> None: + def test_fillna_extension_array_bad_val(self) -> None: srs: pd.Series = pd.Series( - index=np.array([1, 2, 3]), data=pd.Categorical([pd.NA, "a", "b"]) + index=np.array([1, 2, 3]), + data=pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), ) da = srs.to_xarray() + with pytest.raises(ValueError): + da.fillna("a") + + @pytest.mark.parametrize( + "extension_array", + [ + pd.Categorical([pd.NA, "a", "b"]), + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + ], + ids=["categorical", "int64[pyarrow]"], + ) + def test_dropna_extension_array(self, extension_array) -> None: + srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array) + da = srs.to_xarray() filled = da.dropna("index") assert filled.dtype == srs.dtype - assert (filled.values == np.array(["a", "b"])).all() + assert (filled.values == srs.values[1:]).all() def test_fillna(self) -> None: a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") From cf118049c609908b54deb7817f7625ada661aea1 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 17:13:06 +0200 Subject: [PATCH 09/16] (fix): `dataset` tests + marks for pyarrow --- xarray/tests/test_dataarray.py | 12 ++++++-- xarray/tests/test_dataset.py | 51 ++++++++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 033046fad9f..919fdb1a946 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -61,6 +61,7 @@ requires_iris, requires_numexpr, requires_pint, + requires_pyarrow, requires_scipy, requires_sparse, source_ndarray, @@ -3079,7 +3080,11 @@ def test_propagate_attrs(self, func) -> None: "fill_value,extension_array", [ ("a", pd.Categorical([pd.NA, "a", "b"])), - (0, pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]")), + pytest.param( + 0, + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + marks=requires_pyarrow, + ), ], ids=["categorical", "int64[pyarrow]"], ) @@ -3090,6 +3095,7 @@ def test_fillna_extension_array(self, fill_value, extension_array) -> None: assert filled.dtype == srs.dtype assert (filled.values == np.array([fill_value, *(srs.values[1:])])).all() + @requires_pyarrow def test_fillna_extension_array_bad_val(self) -> None: srs: pd.Series = pd.Series( index=np.array([1, 2, 3]), @@ -3103,7 +3109,9 @@ def test_fillna_extension_array_bad_val(self) -> None: "extension_array", [ pd.Categorical([pd.NA, "a", "b"]), - pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + pytest.param( + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), marks=requires_pyarrow + ), ], ids=["categorical", "int64[pyarrow]"], ) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 136d39b56a3..2cbec65e186 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -65,6 +65,7 @@ requires_dask, requires_numexpr, requires_pint, + requires_pyarrow, requires_scipy, requires_sparse, source_ndarray, @@ -1836,7 +1837,9 @@ def test_categorical_index_reindex(self) -> None: ["foo", "bar", "baz"], categories=["foo", "bar", "baz", "qux"], ), - pd.array([1, 1, None], dtype="int64[pyarrow]"), + pytest.param( + pd.array([1, 1, None], dtype="int64[pyarrow]"), marks=requires_pyarrow + ), ], ) def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None: @@ -1858,8 +1861,9 @@ def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> N ) ) # type: ignore[attr-defined] + @requires_pyarrow def test_extension_array_reindex_same(self) -> None: - series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype()) + series = pd.Series([1, 2, pd.NA, 3], dtype="int32[pyarrow]") test = xr.Dataset({"test": series}) res = test.reindex(dim_0=series.index) align(res, test, join="exact") @@ -5454,19 +5458,44 @@ def test_dropna(self) -> None: with pytest.raises(TypeError, match=r"must specify how or thresh"): ds.dropna("a", how=None) # type: ignore[arg-type] - def test_fillna_extension_array_int(self) -> None: - srs = pd.DataFrame({"data": pd.array([pd.NA, 1, 1])}, index=np.array([1, 2, 3])) + @pytest.mark.parametrize( + "fill_value,extension_array", + [ + ("a", pd.Categorical([pd.NA, "a", "b"])), + pytest.param( + 0, + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + marks=requires_pyarrow, + ), + ], + ids=["categorical", "int64[pyarrow]"], + ) + def test_fillna_extension_array(self, fill_value, extension_array) -> None: + srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) ds = srs.to_xarray() - filled = ds.fillna(0) - assert filled["data"].dtype == pd.Int64Dtype() - assert (filled["data"].values == np.array([0, 1, 1])).all() + filled = ds.fillna(fill_value) + assert filled["data"].dtype == extension_array.dtype + assert ( + filled["data"].values + == np.array([fill_value, *srs["data"].values[1:]], dtype="object") + ).all() - def test_dropna_extension_array_int(self) -> None: - srs = pd.DataFrame({"data": pd.array([pd.NA, 1, 1])}, index=np.array([1, 2, 3])) + @pytest.mark.parametrize( + "extension_array", + [ + pd.Categorical([pd.NA, "a", "b"]), + pytest.param( + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), marks=requires_pyarrow + ), + ], + ids=["categorical", "int64[pyarrow]"], + ) + def test_dropna_extension_array(self, extension_array) -> None: + srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) ds = srs.to_xarray() dropped = ds.dropna("index") - assert dropped["data"].dtype == pd.Int64Dtype() - assert (dropped["data"].values == np.array([1, 1])).all() + assert dropped["data"].dtype == extension_array.dtype + assert (dropped["data"].values == srs["data"].values[1:]).all() def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) From a7ab018c9c239219075cbf9bbe9b239d16fd4c98 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 18:30:11 +0200 Subject: [PATCH 10/16] (fix): remove error for 1d arrays --- xarray/core/extension_array.py | 8 ++++---- xarray/tests/test_concat.py | 6 ++++-- xarray/tests/test_dataarray.py | 2 +- xarray/tests/test_dataset.py | 1 + 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 0272c62eaac..cff9ae646f6 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -111,7 +111,7 @@ def replace_duck_with_extension_array(args) -> list: args = tuple(replace_duck_with_extension_array(args)) if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: - return func(*args, **kwargs) + raise KeyError("Function not registered for pandas extension arrays.") res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) if is_extension_array_dtype(res): return PandasExtensionArray(res) @@ -121,9 +121,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: - if isinstance(key, tuple): - if len(key) > 1: - raise IndexError("Too many indices for array.") + if ( + isinstance(key, tuple) and len(key) == 1 + ): # pyarrow type arrays can't handle since-length tuples key = key[0] item = self.array[key] if is_extension_array_dtype(item): diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 5f6909d032c..028acc27186 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -21,6 +21,7 @@ assert_equal, assert_identical, requires_dask, + requires_pyarrow, ) from xarray.tests.test_dataset import create_test_data @@ -154,12 +155,13 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) +@requires_pyarrow def test_concat_extension_array() -> None: data1 = create_test_data(use_extension_array=True) data2 = create_test_data(use_extension_array=True) concatenated = concat([data1, data2], dim="dim1") for var in ["var4", "var5"]: - assert ( + assert pd.Series( concatenated[var] == type(data2[var].variable.data)._concat_same_type( [ @@ -167,7 +169,7 @@ def test_concat_extension_array() -> None: data2[var].variable.data, ] ) - ).all() + ).all() # need to wrap in series because pyarrow bool does not support `all` def test_concat_missing_multiple_consecutive_var() -> None: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 919fdb1a946..7df6f16e18e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3684,7 +3684,7 @@ def test_series_categorical_index(self) -> None: s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc"))) arr = DataArray(s) - assert "'a'" in repr(arr) # should not error + assert "a a b b c" in repr(arr) # should not error @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("data", ["list", "array", True]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2cbec65e186..9cb544b294f 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1841,6 +1841,7 @@ def test_categorical_index_reindex(self) -> None: pd.array([1, 1, None], dtype="int64[pyarrow]"), marks=requires_pyarrow ), ], + ids=["categorical", "int64[pyarrow]"], ) def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None: ds = xr.Dataset( From 4a2170a94017448b888d8c0bc4cf77e546c40ec2 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 19:04:17 +0200 Subject: [PATCH 11/16] (fix): test deps for min versions --- xarray/tests/__init__.py | 14 +++++---- xarray/tests/test_concat.py | 23 +++++++-------- xarray/tests/test_dataarray.py | 27 ++++++++++------- xarray/tests/test_dataset.py | 53 ++++++++++++++++++++++------------ 4 files changed, 70 insertions(+), 47 deletions(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 8e7503cde4b..fe76df75fa0 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -363,12 +363,14 @@ def create_test_data( ) ), ) - obj["var5"] = ( - "dim1", - pd.array( - rs.integers(1, 10, size=dim_sizes[0]).tolist(), dtype="int64[pyarrow]" - ), - ) + if has_pyarrow: + obj["var5"] = ( + "dim1", + pd.array( + rs.integers(1, 10, size=dim_sizes[0]).tolist(), + dtype="int64[pyarrow]", + ), + ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 028acc27186..ed5aac4fe99 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -155,21 +155,20 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) -@requires_pyarrow -def test_concat_extension_array() -> None: +@pytest.mark.parametrize("var", ["var4", pytest.param("var5", marks=requires_pyarrow)]) +def test_concat_extension_array(var) -> None: data1 = create_test_data(use_extension_array=True) data2 = create_test_data(use_extension_array=True) concatenated = concat([data1, data2], dim="dim1") - for var in ["var4", "var5"]: - assert pd.Series( - concatenated[var] - == type(data2[var].variable.data)._concat_same_type( - [ - data1[var].variable.data, - data2[var].variable.data, - ] - ) - ).all() # need to wrap in series because pyarrow bool does not support `all` + assert pd.Series( + concatenated[var] + == type(data2[var].variable.data)._concat_same_type( + [ + data1[var].variable.data, + data2[var].variable.data, + ] + ) + ).all() # need to wrap in series because pyarrow bool does not support `all` def test_concat_missing_multiple_consecutive_var() -> None: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 7df6f16e18e..da93a028349 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -53,6 +53,7 @@ assert_no_warnings, has_dask, has_dask_ge_2025_1_0, + has_pyarrow, raise_if_dask_computes, requires_bottleneck, requires_cupy, @@ -3079,14 +3080,17 @@ def test_propagate_attrs(self, func) -> None: @pytest.mark.parametrize( "fill_value,extension_array", [ - ("a", pd.Categorical([pd.NA, "a", "b"])), + pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="categorical"), + ] + + [ pytest.param( 0, pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), - marks=requires_pyarrow, - ), - ], - ids=["categorical", "int64[pyarrow]"], + id="int64[pyarrow]", + ) + ] + if has_pyarrow + else [], ) def test_fillna_extension_array(self, fill_value, extension_array) -> None: srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array) @@ -3108,12 +3112,15 @@ def test_fillna_extension_array_bad_val(self) -> None: @pytest.mark.parametrize( "extension_array", [ - pd.Categorical([pd.NA, "a", "b"]), + pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="categorical"), + ] + + [ pytest.param( - pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), marks=requires_pyarrow - ), - ], - ids=["categorical", "int64[pyarrow]"], + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else [], ) def test_dropna_extension_array(self, extension_array) -> None: srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9cb544b294f..533169cdf07 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -58,6 +58,7 @@ create_test_data, has_cftime, has_dask, + has_pyarrow, raise_if_dask_computes, requires_bottleneck, requires_cftime, @@ -298,12 +299,14 @@ def test_repr(self) -> None: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B b c b a c a c a - var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1 + var4 (dim1) category 32B b c b a c a c a{} Attributes: foo: bar""".format( data["dim3"].dtype, "ns", + "var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" + if has_pyarrow + else "", ) ) actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) @@ -1833,15 +1836,21 @@ def test_categorical_index_reindex(self) -> None: @pytest.mark.parametrize( "extension_array", [ - pd.Categorical( - ["foo", "bar", "baz"], - categories=["foo", "bar", "baz", "qux"], - ), pytest.param( - pd.array([1, 1, None], dtype="int64[pyarrow]"), marks=requires_pyarrow + pd.Categorical( + ["foo", "bar", "baz"], + categories=["foo", "bar", "baz", "qux"], + ), + id="categorical", ), - ], - ids=["categorical", "int64[pyarrow]"], + ] + + [ + pytest.param( + pd.array([1, 1, None], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else [], ) def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None: ds = xr.Dataset( @@ -5462,14 +5471,17 @@ def test_dropna(self) -> None: @pytest.mark.parametrize( "fill_value,extension_array", [ - ("a", pd.Categorical([pd.NA, "a", "b"])), + pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="category"), + ] + + [ pytest.param( 0, pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), - marks=requires_pyarrow, - ), - ], - ids=["categorical", "int64[pyarrow]"], + id="int64[pyarrow]", + ) + ] + if has_pyarrow + else [], ) def test_fillna_extension_array(self, fill_value, extension_array) -> None: srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) @@ -5484,12 +5496,15 @@ def test_fillna_extension_array(self, fill_value, extension_array) -> None: @pytest.mark.parametrize( "extension_array", [ - pd.Categorical([pd.NA, "a", "b"]), + pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="category"), + ] + + [ pytest.param( - pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), marks=requires_pyarrow - ), - ], - ids=["categorical", "int64[pyarrow]"], + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else [], ) def test_dropna_extension_array(self, extension_array) -> None: srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) From cdedf04aee9e3018737b9d11a6438b8a6b7f72ab Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 19:10:53 +0200 Subject: [PATCH 12/16] (fix): function name --- xarray/core/extension_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index cff9ae646f6..78a51b5ca87 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -52,7 +52,7 @@ def __extension_duck_array__concatenate( @implements(np.reshape) -def __extension_duck_array__concatenate( +def __extension_duck_array__reshape( arr: T_ExtensionArray, shape: tuple ) -> T_ExtensionArray: if (shape[0] == len(arr) and len(shape) == 1) or shape == (-1,): From 651152d4c2a99fe97d7285dd7e71e6b682ac7857 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 19:11:37 +0200 Subject: [PATCH 13/16] (fix): `ignore` comment mypy --- xarray/tests/test_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 533169cdf07..5062dbc290b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1864,12 +1864,12 @@ def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> N pd.api.extensions.ExtensionArray, (ds.reindex(index=[-1, 1, 1], **kwargs)["arr"].to_pandas().values), ) - assert reindexed_cat.equals( + assert reindexed_cat.equals( # type: ignore[attr-defined] pd.array( [pd.NA, extension_array[1], extension_array[1]], dtype=extension_array.dtype, ) - ) # type: ignore[attr-defined] + ) @requires_pyarrow def test_extension_array_reindex_same(self) -> None: From f168040a9692c07c18fa12ac37f4c2315f1466a6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 19:35:21 +0200 Subject: [PATCH 14/16] (fix): explicit skipped line --- xarray/tests/test_dataset.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 5062dbc290b..75cda652da4 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -299,17 +299,20 @@ def test_repr(self) -> None: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B b c b a c a c a{} + var4 (dim1) category 32B b c b a c a c a + {} Attributes: foo: bar""".format( data["dim3"].dtype, "ns", "var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" if has_pyarrow - else "", + else "SKIP_LINE", ) ) - actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) + actual = "\n".join( + x.rstrip() for x in repr(data).split("\n") if x != "SKIP_LINE" + ) assert expected == actual From 710404dea9a93377bde9da58b7a0fd45ce7a808b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 May 2025 19:51:36 +0200 Subject: [PATCH 15/16] (fix): repr test --- xarray/tests/test_dataset.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 75cda652da4..bb8a3ea8ab9 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -299,20 +299,17 @@ def test_repr(self) -> None: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B b c b a c a c a - {} + var4 (dim1) category 32B b c b a c a c a{} Attributes: - foo: bar""".format( - data["dim3"].dtype, - "ns", - "var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" - if has_pyarrow - else "SKIP_LINE", - ) - ) - actual = "\n".join( - x.rstrip() for x in repr(data).split("\n") if x != "SKIP_LINE" + foo: bar""" + ).format( + data["dim3"].dtype, + "ns", + "\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" + if has_pyarrow + else "", ) + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) assert expected == actual From 62aa485030d41f09f5d4ae87b65f892d46a07425 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 May 2025 16:41:48 +0200 Subject: [PATCH 16/16] (fix): add copy + deepcopy --- xarray/core/extension_array.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 78a51b5ca87..d721aac4c29 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Any, Generic, cast @@ -164,3 +165,11 @@ def __getattr__(self, attr: str) -> Any: # (which is apparently the first thing sought in copy.copy from the under-construction copied object), # which would cause a recursion error since `array` is not present on the object when it is being constructed during `__{deep}copy__`. return getattr(super().__getattribute__("array"), attr) + + def __copy__(self) -> PandasExtensionArray[T_ExtensionArray]: + return PandasExtensionArray(copy.copy(self.array)) + + def __deepcopy__( + self, memo: dict[int, Any] | None = None + ) -> PandasExtensionArray[T_ExtensionArray]: + return PandasExtensionArray(copy.deepcopy(self.array, memo=memo))