Skip to content

(fix): pandas extension array repr for int64[pyarrow] #10317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ def __extension_duck_array__where(
return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array)


@implements(np.reshape)
def __extension_duck_array__reshape(
    arr: T_ExtensionArray, shape: tuple
) -> T_ExtensionArray:
    """Handle ``np.reshape`` for a 1d-only pandas extension array.

    Only no-op reshapes are supported: the requested shape must be
    one-dimensional and either equal to ``len(arr)`` or be the ``-1``
    placeholder. In that case ``arr`` itself is returned unchanged.

    Parameters
    ----------
    arr
        The extension array to "reshape".
    shape
        Requested shape. A bare integer or any sequence of integers is
        accepted, mirroring what ``np.reshape`` itself allows.

    Raises
    ------
    NotImplementedError
        If the requested shape is anything other than a 1d shape matching
        the array's length (or ``-1``).
    """
    # Normalise to a tuple so that an int (``5``) or a list (``[-1]``) is
    # treated like the equivalent tuple.
    if isinstance(shape, (int, np.integer)):
        dims = (shape,)
    else:
        dims = tuple(shape)
    # Check the length first: indexing an empty shape would raise IndexError
    # instead of the intended NotImplementedError.
    if len(dims) == 1 and (dims[0] == len(arr) or dims[0] == -1):
        return arr
    raise NotImplementedError(
        f"Cannot reshape 1d-only pandas extension array to: {shape}"
    )


@dataclass(frozen=True)
class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
"""NEP-18 compliant wrapper for pandas extension arrays.
Expand Down Expand Up @@ -101,10 +112,10 @@ def replace_duck_with_extension_array(args) -> list:

args = tuple(replace_duck_with_extension_array(args))
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
return func(*args, **kwargs)
raise KeyError("Function not registered for pandas extension arrays.")
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
if is_extension_array_dtype(res):
return type(self)[type(res)](res)
return PandasExtensionArray(res)
return res

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
Expand Down
10 changes: 9 additions & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype, ravel
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.indexing import MemoryCachedArray
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.treenode import group_subtrees
Expand Down Expand Up @@ -176,6 +177,11 @@ def format_timedelta(t, timedelta_format=None):

def format_item(x, timedelta_format=None, quote_strings=True):
"""Returns a succinct summary of an object as a string"""
if isinstance(x, PandasExtensionArray):
# We want to bypass PandasExtensionArray's repr here
# because its __repr__ is PandasExtensionArray(array=[...])
# and this function is only for single elements.
return str(x.array[0])
if isinstance(x, np.datetime64 | datetime):
return format_timestamp(x)
if isinstance(x, np.timedelta64 | timedelta):
Expand All @@ -194,7 +200,9 @@ def format_items(x):
"""Returns a succinct summaries of all items in a sequence as strings"""
x = to_duck_array(x)
timedelta_format = "datetime"
if np.issubdtype(x.dtype, np.timedelta64):
if not isinstance(x, PandasExtensionArray) and np.issubdtype(
x.dtype, np.timedelta64
):
x = astype(x, dtype="timedelta64[ns]")
day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]")
time_needed = x[~pd.isnull(x)] != day_part
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,14 @@ def create_test_data(
)
),
)
if has_pyarrow:
obj["var5"] = (
"dim1",
pd.array(
rs.integers(1, 10, size=dim_sizes[0]).tolist(),
dtype="int64[pyarrow]",
),
)
if dim_sizes == _DEFAULT_TEST_DIM_SIZES:
numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64")
else:
Expand Down
16 changes: 9 additions & 7 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
assert_equal,
assert_identical,
requires_dask,
requires_pyarrow,
)
from xarray.tests.test_dataset import create_test_data

Expand Down Expand Up @@ -154,19 +155,20 @@ def test_concat_missing_var() -> None:
assert_identical(actual, expected)


def test_concat_categorical() -> None:
@pytest.mark.parametrize("var", ["var4", pytest.param("var5", marks=requires_pyarrow)])
def test_concat_extension_array(var) -> None:
data1 = create_test_data(use_extension_array=True)
data2 = create_test_data(use_extension_array=True)
concatenated = concat([data1, data2], dim="dim1")
assert (
concatenated["var4"]
== type(data2["var4"].variable.data)._concat_same_type(
assert pd.Series(
concatenated[var]
== type(data2[var].variable.data)._concat_same_type(
[
data1["var4"].variable.data,
data2["var4"].variable.data,
data1[var].variable.data,
data2[var].variable.data,
]
)
).all()
).all() # need to wrap in series because pyarrow bool does not support `all`


def test_concat_missing_multiple_consecutive_var() -> None:
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3637,7 +3637,7 @@ def test_series_categorical_index(self) -> None:

s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc")))
arr = DataArray(s)
assert "'a'" in repr(arr) # should not error
assert "a a b b" in repr(arr) # should not error

@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("data", ["list", "array", True])
Expand Down
26 changes: 15 additions & 11 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
create_test_data,
has_cftime,
has_dask,
has_pyarrow,
raise_if_dask_computes,
requires_bottleneck,
requires_cftime,
Expand Down Expand Up @@ -283,26 +284,28 @@ def test_repr(self) -> None:
data = create_test_data(seed=123, use_extension_array=True)
data.attrs["foo"] = "bar"
# need to insert str dtype at runtime to handle different endianness
var5 = (
"\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1"
if has_pyarrow
else ""
)
expected = dedent(
"""\
f"""\
<xarray.Dataset> Size: 2kB
Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8)
Coordinates:
* dim2 (dim2) float64 72B 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0
* dim3 (dim3) {} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
* time (time) datetime64[{}] 160B 2000-01-01 2000-01-02 ... 2000-01-20
* dim3 (dim3) {data["dim3"].dtype} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
* time (time) datetime64[ns] 160B 2000-01-01 2000-01-02 ... 2000-01-20
numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3
Dimensions without coordinates: dim1
Data variables:
var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364
var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423
var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555
var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a'
var4 (dim1) category 32B b c b a c a c a{var5}
Attributes:
foo: bar""".format(
data["dim3"].dtype,
"ns",
)
foo: bar"""
)
actual = "\n".join(x.rstrip() for x in repr(data).split("\n"))

Expand Down Expand Up @@ -5884,20 +5887,21 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None:
def test_reduce_non_numeric(self) -> None:
data1 = create_test_data(seed=44, use_extension_array=True)
data2 = create_test_data(seed=44)
add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]}
add_vars = {"var6": ["dim1", "dim2"], "var7": ["dim1"]}
for v, dims in sorted(add_vars.items()):
size = tuple(data1.sizes[d] for d in dims)
data = np.random.randint(0, 100, size=size).astype(np.str_)
data1[v] = (dims, data, {"foo": "variable"})
# var4 is extension array categorical and should be dropped
# var4 and var5 are extension arrays and should be dropped
assert (
"var4" not in data1.mean()
and "var5" not in data1.mean()
and "var6" not in data1.mean()
and "var7" not in data1.mean()
)
assert_equal(data1.mean(), data2.mean())
assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1"))
assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2")
assert "var6" not in data1.mean(dim="dim2") and "var7" in data1.mean(dim="dim2")

@pytest.mark.filterwarnings(
"ignore:Once the behaviour of DataArray:DeprecationWarning"
Expand Down
Loading