diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index f1631a5ea9e..0e77bc5f409 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -66,6 +66,17 @@ def __extension_duck_array__where( return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) +@implements(np.reshape) +def __extension_duck_array__reshape( + arr: T_ExtensionArray, shape: tuple +) -> T_ExtensionArray: + if (shape[0] == len(arr) and len(shape) == 1) or shape == (-1,): + return arr + raise NotImplementedError( + f"Cannot reshape 1d-only pandas extension array to: {shape}" + ) + + @dataclass(frozen=True) class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin): """NEP-18 compliant wrapper for pandas extension arrays. @@ -101,10 +112,10 @@ def replace_duck_with_extension_array(args) -> list: args = tuple(replace_duck_with_extension_array(args)) if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: - return func(*args, **kwargs) + raise KeyError("Function not registered for pandas extension arrays.") res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) if is_extension_array_dtype(res): - return type(self)[type(res)](res) + return PandasExtensionArray(res) return res def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 756ca0b9803..fb672997706 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -19,6 +19,7 @@ from xarray.core.datatree_render import RenderDataTree from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype, ravel +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import MemoryCachedArray from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.treenode import group_subtrees @@ -176,6 +177,11 @@ def format_timedelta(t, timedelta_format=None): def format_item(x, timedelta_format=None, quote_strings=True): """Returns a succinct summary of an object as a string""" + if isinstance(x, PandasExtensionArray): + # We want to bypass PandasExtensionArray's repr here + # because its __repr__ is PandasExtensionArray(array=[...]) + # and this function is only for single elements. + return str(x.array[0]) if isinstance(x, np.datetime64 | datetime): return format_timestamp(x) if isinstance(x, np.timedelta64 | timedelta): @@ -194,7 +200,9 @@ def format_items(x): """Returns a succinct summaries of all items in a sequence as strings""" x = to_duck_array(x) timedelta_format = "datetime" - if np.issubdtype(x.dtype, np.timedelta64): + if not isinstance(x, PandasExtensionArray) and np.issubdtype( + x.dtype, np.timedelta64 + ): x = astype(x, dtype="timedelta64[ns]") day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index b33192393f7..fe76df75fa0 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -363,6 +363,14 @@ def create_test_data( ) ), ) + if has_pyarrow: + obj["var5"] = ( + "dim1", + pd.array( + rs.integers(1, 10, size=dim_sizes[0]).tolist(), + dtype="int64[pyarrow]", + ), + ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 49c6490d819..ed5aac4fe99 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -21,6 +21,7 @@ assert_equal, assert_identical, requires_dask, + requires_pyarrow, ) from xarray.tests.test_dataset import create_test_data @@ -154,19 +155,20 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) -def test_concat_categorical() -> None: +@pytest.mark.parametrize("var", ["var4", pytest.param("var5", marks=requires_pyarrow)]) +def test_concat_extension_array(var) -> None: data1 = create_test_data(use_extension_array=True) data2 = create_test_data(use_extension_array=True) concatenated = concat([data1, data2], dim="dim1") - assert ( - concatenated["var4"] - == type(data2["var4"].variable.data)._concat_same_type( + assert pd.Series( + concatenated[var] + == type(data2[var].variable.data)._concat_same_type( [ - data1["var4"].variable.data, - data2["var4"].variable.data, + data1[var].variable.data, + data2[var].variable.data, ] ) - ).all() + ).all() # need to wrap in series because pyarrow bool does not support `all` def test_concat_missing_multiple_consecutive_var() -> None: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e7acdcdd4f3..3640bf4df14 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3637,7 +3637,7 @@ def test_series_categorical_index(self) -> None: s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc"))) arr = DataArray(s) - assert "'a'" in repr(arr) # should not error + assert "a a b b" in repr(arr) # should not error @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("data", ["list", "array", True]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2ec2b20cac5..aaef95d6e82 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -60,6 +60,7 @@ create_test_data, has_cftime, has_dask, + has_pyarrow, raise_if_dask_computes, requires_bottleneck, requires_cftime, @@ -283,26 +284,28 @@ def test_repr(self) -> None: data = create_test_data(seed=123, use_extension_array=True) data.attrs["foo"] = "bar" # need to insert str dtype at runtime to handle different endianness + var5 = ( + "\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" + if has_pyarrow + else "" + ) expected = dedent( - """\ + f"""\ Size: 2kB Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8) Coordinates: * dim2 (dim2) float64 72B 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 - * dim3 (dim3) {} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' - * time (time) datetime64[{}] 160B 2000-01-01 2000-01-02 ... 2000-01-20 + * dim3 (dim3) {data["dim3"].dtype} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' + * time (time) datetime64[ns] 160B 2000-01-01 2000-01-02 ... 2000-01-20 numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3 Dimensions without coordinates: dim1 Data variables: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a' + var4 (dim1) category 32B b c b a c a c a{var5} Attributes: - foo: bar""".format( - data["dim3"].dtype, - "ns", - ) + foo: bar""" ) actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) @@ -5884,20 +5887,21 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: def test_reduce_non_numeric(self) -> None: data1 = create_test_data(seed=44, use_extension_array=True) data2 = create_test_data(seed=44) - add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]} + add_vars = {"var6": ["dim1", "dim2"], "var7": ["dim1"]} for v, dims in sorted(add_vars.items()): size = tuple(data1.sizes[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) data1[v] = (dims, data, {"foo": "variable"}) - # var4 is extension array categorical and should be dropped + # var4 and var5 are extension arrays and should be dropped assert ( "var4" not in data1.mean() and "var5" not in data1.mean() and "var6" not in data1.mean() + and "var7" not in data1.mean() ) assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) - assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2") + assert "var6" not in data1.mean(dim="dim2") and "var7" in data1.mean(dim="dim2") @pytest.mark.filterwarnings( "ignore:Once the behaviour of DataArray:DeprecationWarning"