Skip to content

Improve support for pandas Extension Arrays (#10301) #10380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
if is_extension_array_dtype(dtype):
return dtype, dtype.na_value
elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
# for now, we always promote string dtypes to object for consistency with existing behavior
# TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
dtype_ = object
Expand Down Expand Up @@ -222,19 +224,51 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
return xp.isdtype(dtype, kind)


def preprocess_types(t):
if isinstance(t, str | bytes):
return type(t)
elif isinstance(dtype := getattr(t, "dtype", t), np.dtype) and (
np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)
):
def maybe_promote_to_variable_width(
array_or_dtype: np.typing.ArrayLike | np.typing.DTypeLike,
) -> np.typing.ArrayLike | np.typing.DTypeLike:
if isinstance(array_or_dtype, str | bytes):
return type(array_or_dtype)
elif isinstance(
dtype := getattr(array_or_dtype, "dtype", array_or_dtype), np.dtype
) and (np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)):
# drop the length from numpy's fixed-width string dtypes, it is better to
# recalculate
# TODO(keewis): remove once the minimum version of `numpy.result_type` does this
# for us
return dtype.type
else:
return t
return array_or_dtype
Comment on lines -225 to +241
Copy link
Author

@richard-berg richard-berg May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This diff looks ugly, but it's simply renaming the fn + its argument.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that this does not promote to variable width strings (which would be the new string dtype, numpy.dtypes.StringDType) but rather drops the width of an existing fixed-width string dtype to force numpy to recalculate the width. The aim is to avoid truncating a python string object.

Additionally, it might be good to keep the "type" in the name of the function, since it only operates on string dtypes.



def should_promote_to_object(
    arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, xp
) -> bool:
    """
    Test whether the given arrays_and_dtypes, when evaluated individually, match the
    type promotion rules found in PROMOTE_TO_OBJECT.
    """
    numpy_dtypes: set[np.dtype] = set()
    for candidate in arrays_and_dtypes:
        try:
            promoted = array_api_compat.result_type(
                maybe_promote_to_variable_width(candidate), xp=xp
            )
        except TypeError:
            # passing individual objects to xp.result_type means NEP-18 implementations
            # won't have a chance to intercept special values (such as NA) that numpy
            # core cannot handle — skip those here and let the caller's combined
            # result_type call deal with them
            continue
        # only numpy dtypes participate: the array API does not define the
        # kinds checked by PROMOTE_TO_OBJECT
        if isinstance(promoted, np.dtype):
            numpy_dtypes.add(promoted)

    if not numpy_dtypes:
        return False

    # promote when both sides of any (left, right) rule are represented
    return any(
        any(np.issubdtype(t, left) for t in numpy_dtypes)
        and any(np.issubdtype(t, right) for t in numpy_dtypes)
        for left, right in PROMOTE_TO_OBJECT
    )


def result_type(
Expand Down Expand Up @@ -263,19 +297,9 @@ def result_type(
if xp is None:
xp = get_array_namespace(arrays_and_dtypes)

types = {
array_api_compat.result_type(preprocess_types(t), xp=xp)
for t in arrays_and_dtypes
}
if any(isinstance(t, np.dtype) for t in types):
# only check if there's numpy dtypes – the array API does not
# define the types we're checking for
for left, right in PROMOTE_TO_OBJECT:
if any(np.issubdtype(t, left) for t in types) and any(
np.issubdtype(t, right) for t in types
):
return np.dtype(object)
if should_promote_to_object(arrays_and_dtypes, xp):
return np.dtype(object)

return array_api_compat.result_type(
*map(preprocess_types, arrays_and_dtypes), xp=xp
*map(maybe_promote_to_variable_width, arrays_and_dtypes), xp=xp
)
34 changes: 19 additions & 15 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
from xarray.compat import dask_array_compat, dask_array_ops
from xarray.compat.array_api_compat import get_array_namespace
from xarray.core import dtypes, nputils
from xarray.core.extension_array import (
PandasExtensionArray,
as_extension_array,
is_scalar,
)
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.namedarray.parallelcompat import get_chunked_array_type
Expand Down Expand Up @@ -239,7 +244,14 @@ def astype(data, dtype, *, xp=None, **kwargs):


def asarray(data, xp=np, dtype=None):
converted = data if is_duck_array(data) else xp.asarray(data)
if is_duck_array(data):
converted = data
elif is_extension_array_dtype(dtype):
# data may or may not be an ExtensionArray, so we can't rely on
# np.asarray to call our NEP-18 handler; gotta hook it ourselves
converted = PandasExtensionArray(as_extension_array(data, dtype))
Copy link
Contributor

@ilan-gold ilan-gold Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
converted = PandasExtensionArray(as_extension_array(data, dtype))
converted = PandasExtensionArray(np.asarray(data, dtype))

Isn't this what as_extension_array is meant for, really — i.e., to be used by the above suggested API?

else:
converted = xp.asarray(data, dtype=dtype)

if dtype is None or converted.dtype == dtype:
return converted
Expand All @@ -252,19 +264,6 @@ def asarray(data, xp=np, dtype=None):

def as_shared_dtype(scalars_or_arrays, xp=None):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
if len(extension_array_types) == len(scalars_or_arrays) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return scalars_or_arrays
raise ValueError(
"Cannot cast arrays to shared type, found"
f" array types {[x.dtype for x in scalars_or_arrays]}"
)

Comment on lines -255 to -267
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we now provide an array-api implementation of np.result_type, we no longer need these special cases. (which were far too special, IMO; the cases where we raised ValueError are perfectly valid)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking this on, it's great work!

Can you extract out any non-EA changes to duck_array_ops.py and dtypes.py and make a separate PR please? It will be far easier to review then

# Avoid calling array_type("cupy") repeatedly in the any check
array_type_cupy = array_type("cupy")
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
Expand Down Expand Up @@ -384,7 +383,12 @@ def where(condition, x, y):
else:
condition = astype(condition, dtype=dtype, xp=xp)

return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
promoted_x, promoted_y = as_shared_dtype([x, y], xp=xp)

# pd.where won't broadcast 0-dim arrays across a series; scalar y's must be preserved
maybe_promoted_y = y if is_extension_array_dtype(x) and is_scalar(y) else promoted_y

return xp.where(condition, promoted_x, maybe_promoted_y)


def where_method(data, cond, other=dtypes.NA):
Expand Down
Loading