(fix): no fill_value on reindex #10304

Draft: wants to merge 18 commits into base: main
21 changes: 16 additions & 5 deletions properties/test_pandas_roundtrip.py
@@ -134,10 +134,21 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
xr.testing.assert_identical(dataset, roundtripped.to_xarray())


def test_roundtrip_1d_pandas_extension_array() -> None:
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
arr = xr.Dataset.from_dataframe(df)["cat"]
@pytest.mark.parametrize(
"extension_array",
[
pd.Categorical(["a", "b", "c"]),
pd.array([1, 2, 3], dtype="int64"),
pd.array(["a", "b", "c"], dtype="string"),
pd.arrays.IntervalArray(
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
),
],
)
def test_roundtrip_1d_pandas_extension_array(extension_array) -> None:
df = pd.DataFrame({"arr": extension_array})
arr = xr.Dataset.from_dataframe(df)["arr"]
roundtripped = arr.to_pandas()
assert (df["cat"] == roundtripped).all()
assert df["cat"].dtype == roundtripped.dtype
assert (df["arr"] == roundtripped).all()
assert df["arr"].dtype == roundtripped.dtype
xr.testing.assert_identical(arr, roundtripped.to_xarray())
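
For context, a minimal standalone sketch of the roundtrip these new parametrized cases exercise; the column name "arr" and the string dtype are just for illustration, not part of the diff.

import pandas as pd
import xarray as xr

# A pandas extension array (here the nullable string dtype) stored in a DataFrame.
df = pd.DataFrame({"arr": pd.array(["a", "b", "c"], dtype="string")})

# Convert to xarray and back; the extension dtype is expected to survive.
arr = xr.Dataset.from_dataframe(df)["arr"]
roundtripped = arr.to_pandas()
assert roundtripped.dtype == df["arr"].dtype
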
5 changes: 4 additions & 1 deletion xarray/core/dtypes.py
@@ -4,6 +4,7 @@
from typing import Any

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

from xarray.compat import array_api_compat, npcompat
@@ -63,7 +64,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
if pd.api.types.is_extension_array_dtype(dtype):
return dtype, pd.NA
elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
# for now, we always promote string dtypes to object for consistency with existing behavior
# TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
dtype_ = object
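
A hedged sketch of what the new branch is meant to do, using the internal helper touched above: for a pandas extension dtype, maybe_promote keeps the dtype and uses pd.NA as the fill value (so operations such as reindex no longer promote to object), while numpy dtypes keep the existing rules. This illustrates intended behavior and is not part of the diff.

import numpy as np
import pandas as pd
from xarray.core.dtypes import maybe_promote

# Extension dtype: the dtype is preserved and pd.NA is the fill value.
dtype, fill = maybe_promote(pd.CategoricalDtype(["a", "b"]))
assert isinstance(dtype, pd.CategoricalDtype) and fill is pd.NA

# Numpy dtype: unchanged behavior, e.g. int64 promotes to float64 with NaN.
dtype, fill = maybe_promote(np.dtype("int64"))
assert dtype == np.dtype("float64") and np.isnan(fill)
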
30 changes: 23 additions & 7 deletions xarray/core/duck_array_ops.py
@@ -273,18 +273,35 @@ def as_shared_dtype(scalars_or_arrays, xp=None):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
non_nans = [x for x in scalars_or_arrays if not isna(x)]
if len(extension_array_types) == len(non_nans) and all(
non_nans_or_scalar = [
x for x in scalars_or_arrays if not (isna(x) or np.isscalar(x))
]
if len(extension_array_types) == len(non_nans_or_scalar) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return [
# Get the extension array class of the first element, guaranteed to be the same
# as the others thanks to the above check.
extension_array_class = type(
non_nans_or_scalar[0].array
if isinstance(non_nans_or_scalar[0], PandasExtensionArray)
else non_nans_or_scalar[0]
)
# Cast scalars/nans to extension array class
arrays_with_nan_to_sequence = [
x
if not isna(x)
else PandasExtensionArray(
type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
if not (isna(x) or np.isscalar(x))
else extension_array_class._from_sequence(
[x], dtype=non_nans_or_scalar[0].dtype
)
for x in scalars_or_arrays
]
# Wrap the output if necessary
return [
PandasExtensionArray(x)
if not isinstance(x, PandasExtensionArray)
else x
for x in arrays_with_nan_to_sequence
]
raise ValueError(
f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
)
@@ -407,7 +424,6 @@ def where(condition, x, y):
condition = asarray(condition, dtype=dtype, xp=xp)
else:
condition = astype(condition, dtype=dtype, xp=xp)

return xp.where(condition, *as_shared_dtype([x, y], xp=xp))


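
The casting pattern used above, shown in isolation as a hedged illustration of the pandas API being relied on (not xarray's actual code path): a scalar or NA fill value is turned into a one-element extension array of the target dtype via ExtensionArray._from_sequence, so that everything passed on to where/concatenate shares one dtype.

import pandas as pd

cat = pd.Categorical(["a", "b", "c"])

# Cast a plain scalar to a length-1 extension array with the same dtype.
fill = type(cat)._from_sequence(["a"], dtype=cat.dtype)
assert fill.dtype == cat.dtype and len(fill) == 1

# Missing values go through the same constructor and become NA of that dtype.
missing = type(cat)._from_sequence([pd.NA], dtype=cat.dtype)
assert pd.isna(missing[0])
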
19 changes: 17 additions & 2 deletions xarray/core/extension_array.py
@@ -52,6 +52,17 @@ def __extension_duck_array__concatenate(
return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined]


@implements(np.reshape)
def __extension_duck_array__reshape(
arr: T_ExtensionArray, shape: tuple
) -> T_ExtensionArray:
if (shape[0] == len(arr) and len(shape) == 1) or shape == (-1,):
return arr
raise NotImplementedError(
f"Cannot reshape 1d-only pandas extension array to: {shape}"
)


@implements(np.where)
def __extension_duck_array__where(
condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
@@ -101,16 +112,20 @@ def replace_duck_with_extension_array(args) -> list:

args = tuple(replace_duck_with_extension_array(args))
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
return func(*args, **kwargs)
raise KeyError("Function not registered for pandas extension arrays.")
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
if is_extension_array_dtype(res):
return type(self)[type(res)](res)
return PandasExtensionArray(res)
return res

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
return ufunc(*inputs, **kwargs)

def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
if (
isinstance(key, tuple) and len(key) == 1
): # pyarrow type arrays can't handle single-length tuples
key = key[0]
item = self.array[key]
if is_extension_array_dtype(item):
return PandasExtensionArray(item)
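
A minimal sketch of the reshape handler added above, using xarray's internal PandasExtensionArray wrapper (internal API, shown only for illustration): reshapes that keep the array 1-d with the same length are no-ops, and anything that would add a dimension raises.

import numpy as np
import pandas as pd
from xarray.core.extension_array import PandasExtensionArray

wrapped = PandasExtensionArray(pd.Categorical(["a", "b", "c"]))

np.reshape(wrapped, (-1,))   # no-op: the 1-d array comes back unchanged
np.reshape(wrapped, (3,))    # also a no-op: same length, still 1-d

try:
    np.reshape(wrapped, (3, 1))  # extension arrays are strictly 1-d
except NotImplementedError as err:
    print(err)
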
7 changes: 6 additions & 1 deletion xarray/core/formatting.py
@@ -19,6 +19,7 @@

from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype, ravel
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.indexing import MemoryCachedArray
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.treenode import group_subtrees
@@ -176,6 +177,8 @@ def format_timedelta(t, timedelta_format=None):

def format_item(x, timedelta_format=None, quote_strings=True):
"""Returns a succinct summary of an object as a string"""
if isinstance(x, PandasExtensionArray):
return f"{x.array[0]}"
if isinstance(x, np.datetime64 | datetime):
return format_timestamp(x)
if isinstance(x, np.timedelta64 | timedelta):
@@ -194,7 +197,9 @@ def format_items(x):
"""Returns a succinct summaries of all items in a sequence as strings"""
x = to_duck_array(x)
timedelta_format = "datetime"
if np.issubdtype(x.dtype, np.timedelta64):
if not isinstance(x, PandasExtensionArray) and np.issubdtype(
x.dtype, np.timedelta64
):
x = astype(x, dtype="timedelta64[ns]")
day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]")
time_needed = x[~pd.isnull(x)] != day_part
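
The user-visible effect of these formatting tweaks, sketched under the assumption that extension-array-backed items now render as their plain values; the expected output is inferred from the repr test update further down in test_dataarray.py.

import pandas as pd
import xarray as xr

s = pd.Series(range(5), index=pd.CategoricalIndex(list("aabbc")))
arr = xr.DataArray(s)

# The categorical coordinate is expected to render as plain values,
# e.g. "a a b b c", rather than quoted items such as "'a'".
print(repr(arr))
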
8 changes: 8 additions & 0 deletions xarray/tests/__init__.py
@@ -363,6 +363,14 @@ def create_test_data(
)
),
)
if has_pyarrow:
obj["var5"] = (
"dim1",
pd.array(
rs.integers(1, 10, size=dim_sizes[0]).tolist(),
dtype="int64[pyarrow]",
),
)
if dim_sizes == _DEFAULT_TEST_DIM_SIZES:
numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64")
else:
16 changes: 9 additions & 7 deletions xarray/tests/test_concat.py
@@ -21,6 +21,7 @@
assert_equal,
assert_identical,
requires_dask,
requires_pyarrow,
)
from xarray.tests.test_dataset import create_test_data

@@ -154,19 +155,20 @@ def test_concat_missing_var() -> None:
assert_identical(actual, expected)


def test_concat_categorical() -> None:
@pytest.mark.parametrize("var", ["var4", pytest.param("var5", marks=requires_pyarrow)])
def test_concat_extension_array(var) -> None:
data1 = create_test_data(use_extension_array=True)
data2 = create_test_data(use_extension_array=True)
concatenated = concat([data1, data2], dim="dim1")
assert (
concatenated["var4"]
== type(data2["var4"].variable.data)._concat_same_type(
assert pd.Series(
concatenated[var]
== type(data2[var].variable.data)._concat_same_type(
[
data1["var4"].variable.data,
data2["var4"].variable.data,
data1[var].variable.data,
data2[var].variable.data,
]
)
).all()
).all() # need to wrap in series because pyarrow bool does not support `all`


def test_concat_missing_multiple_consecutive_var() -> None:
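
For context, a hedged end-to-end sketch of what the parametrized concat test checks, written against the public API; pyarrow-backed dtypes are expected to behave the same way when pyarrow is installed.

import pandas as pd
import xarray as xr

cats = ["a", "b", "c"]
ds1 = pd.DataFrame({"cat": pd.Categorical(["a", "b"], categories=cats)}).to_xarray()
ds2 = pd.DataFrame({"cat": pd.Categorical(["b", "c"], categories=cats)}).to_xarray()

out = xr.concat([ds1, ds2], dim="index")
# The concatenated variable is expected to keep its extension dtype
# rather than falling back to object.
print(out["cat"].dtype)
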
56 changes: 55 additions & 1 deletion xarray/tests/test_dataarray.py
@@ -53,6 +53,7 @@
assert_no_warnings,
has_dask,
has_dask_ge_2025_1_0,
has_pyarrow,
raise_if_dask_computes,
requires_bottleneck,
requires_cupy,
@@ -61,6 +62,7 @@
requires_iris,
requires_numexpr,
requires_pint,
requires_pyarrow,
requires_scipy,
requires_sparse,
source_ndarray,
@@ -3075,6 +3077,58 @@ def test_propagate_attrs(self, func) -> None:
with set_options(keep_attrs=True):
assert func(da).attrs == da.attrs

@pytest.mark.parametrize(
"fill_value,extension_array",
[
pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="categorical"),
]
+ (
[
pytest.param(
0,
pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"),
id="int64[pyarrow]",
)
]
if has_pyarrow
else []
),
)
def test_fillna_extension_array(self, fill_value, extension_array) -> None:
srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array)
da = srs.to_xarray()
filled = da.fillna(fill_value)
assert filled.dtype == srs.dtype
assert (filled.values == np.array([fill_value, *(srs.values[1:])])).all()

@requires_pyarrow
def test_fillna_extension_array_bad_val(self) -> None:
srs: pd.Series = pd.Series(
index=np.array([1, 2, 3]),
data=pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"),
)
da = srs.to_xarray()
with pytest.raises(ValueError):
da.fillna("a")

@pytest.mark.parametrize(
"extension_array",
[
pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="categorical"),
]
+ (
[
pytest.param(
pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]"
)
]
if has_pyarrow
else []
),
)
def test_dropna_extension_array(self, extension_array) -> None:
srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array)
da = srs.to_xarray()
filled = da.dropna("index")
assert filled.dtype == srs.dtype
assert (filled.values == srs.values[1:]).all()

def test_fillna(self) -> None:
a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")
actual = a.fillna(-1)
@@ -3637,7 +3691,7 @@ def test_series_categorical_index(self) -> None:

s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc")))
arr = DataArray(s)
assert "'a'" in repr(arr) # should not error
assert "a a b b c" in repr(arr) # should not error

@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("data", ["list", "array", True])
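
Finally, a hedged sketch of the user-facing behavior these tests pin down: filling, dropping, or reindexing an extension-array-backed DataArray should keep the extension dtype instead of decaying to object. The reindex step is an assumption based on the maybe_promote change above and the PR title.

import numpy as np
import pandas as pd

srs = pd.Series(index=np.array([1, 2, 3]), data=pd.Categorical([pd.NA, "a", "b"]))
da = srs.to_xarray()

filled = da.fillna("a")
assert filled.dtype == srs.dtype          # still categorical

dropped = da.dropna("index")
assert dropped.dtype == srs.dtype

# With the maybe_promote change, reindexing should fill new positions
# with pd.NA while keeping the extension dtype.
reindexed = da.reindex(index=[1, 2, 3, 4])
print(reindexed.dtype)
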