-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Improve support for pandas Extension Arrays (#10301) #10380
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,7 +63,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: | |
# N.B. these casting rules should match pandas | ||
dtype_: np.typing.DTypeLike | ||
fill_value: Any | ||
if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): | ||
if is_extension_array_dtype(dtype): | ||
return dtype, dtype.na_value | ||
elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): | ||
# for now, we always promote string dtypes to object for consistency with existing behavior | ||
# TODO: refactor this once we have a better way to handle numpy vlen-string dtypes | ||
dtype_ = object | ||
|
@@ -222,19 +224,51 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool: | |
return xp.isdtype(dtype, kind) | ||
|
||
|
||
def preprocess_types(t): | ||
if isinstance(t, str | bytes): | ||
return type(t) | ||
elif isinstance(dtype := getattr(t, "dtype", t), np.dtype) and ( | ||
np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_) | ||
): | ||
def maybe_promote_to_variable_width( | ||
array_or_dtype: np.typing.ArrayLike | np.typing.DTypeLike, | ||
) -> np.typing.ArrayLike | np.typing.DTypeLike: | ||
if isinstance(array_or_dtype, str | bytes): | ||
return type(array_or_dtype) | ||
elif isinstance( | ||
dtype := getattr(array_or_dtype, "dtype", array_or_dtype), np.dtype | ||
) and (np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)): | ||
# drop the length from numpy's fixed-width string dtypes, it is better to | ||
# recalculate | ||
# TODO(keewis): remove once the minimum version of `numpy.result_type` does this | ||
# for us | ||
return dtype.type | ||
else: | ||
return t | ||
return array_or_dtype | ||
|
||
|
||
def should_promote_to_object( | ||
arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, xp | ||
) -> bool: | ||
""" | ||
Test whether the given arrays_and_dtypes, when evaluated individually, match the | ||
type promotion rules found in PROMOTE_TO_OBJECT. | ||
""" | ||
np_result_types = set() | ||
for arr_or_dtype in arrays_and_dtypes: | ||
try: | ||
result_type = array_api_compat.result_type( | ||
maybe_promote_to_variable_width(arr_or_dtype), xp=xp | ||
) | ||
if isinstance(result_type, np.dtype): | ||
np_result_types.add(result_type) | ||
except TypeError: | ||
# passing individual objects to xp.result_type means NEP-18 implementations won't have | ||
# a chance to intercept special values (such as NA) that numpy core cannot handle | ||
pass | ||
Comment on lines
+259
to
+262
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only this |
||
|
||
if np_result_types: | ||
for left, right in PROMOTE_TO_OBJECT: | ||
if any(np.issubdtype(t, left) for t in np_result_types) and any( | ||
np.issubdtype(t, right) for t in np_result_types | ||
): | ||
return True | ||
|
||
return False | ||
|
||
|
||
def result_type( | ||
|
@@ -263,19 +297,9 @@ def result_type( | |
if xp is None: | ||
xp = get_array_namespace(arrays_and_dtypes) | ||
|
||
types = { | ||
array_api_compat.result_type(preprocess_types(t), xp=xp) | ||
for t in arrays_and_dtypes | ||
} | ||
if any(isinstance(t, np.dtype) for t in types): | ||
# only check if there's numpy dtypes – the array API does not | ||
# define the types we're checking for | ||
for left, right in PROMOTE_TO_OBJECT: | ||
if any(np.issubdtype(t, left) for t in types) and any( | ||
np.issubdtype(t, right) for t in types | ||
): | ||
return np.dtype(object) | ||
if should_promote_to_object(arrays_and_dtypes, xp): | ||
return np.dtype(object) | ||
|
||
return array_api_compat.result_type( | ||
*map(preprocess_types, arrays_and_dtypes), xp=xp | ||
*map(maybe_promote_to_variable_width, arrays_and_dtypes), xp=xp | ||
) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -27,6 +27,11 @@ | |||||
from xarray.compat import dask_array_compat, dask_array_ops | ||||||
from xarray.compat.array_api_compat import get_array_namespace | ||||||
from xarray.core import dtypes, nputils | ||||||
from xarray.core.extension_array import ( | ||||||
PandasExtensionArray, | ||||||
as_extension_array, | ||||||
is_scalar, | ||||||
) | ||||||
from xarray.core.options import OPTIONS | ||||||
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available | ||||||
from xarray.namedarray.parallelcompat import get_chunked_array_type | ||||||
|
@@ -239,7 +244,14 @@ def astype(data, dtype, *, xp=None, **kwargs): | |||||
|
||||||
|
||||||
def asarray(data, xp=np, dtype=None): | ||||||
converted = data if is_duck_array(data) else xp.asarray(data) | ||||||
if is_duck_array(data): | ||||||
converted = data | ||||||
elif is_extension_array_dtype(dtype): | ||||||
# data may or may not be an ExtensionArray, so we can't rely on | ||||||
# np.asarray to call our NEP-18 handler; gotta hook it ourselves | ||||||
converted = PandasExtensionArray(as_extension_array(data, dtype)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Isn't this what |
||||||
else: | ||||||
converted = xp.asarray(data, dtype=dtype) | ||||||
|
||||||
if dtype is None or converted.dtype == dtype: | ||||||
return converted | ||||||
|
@@ -252,19 +264,6 @@ def asarray(data, xp=np, dtype=None): | |||||
|
||||||
def as_shared_dtype(scalars_or_arrays, xp=None): | ||||||
"""Cast a arrays to a shared dtype using xarray's type promotion rules.""" | ||||||
if any(is_extension_array_dtype(x) for x in scalars_or_arrays): | ||||||
extension_array_types = [ | ||||||
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) | ||||||
] | ||||||
if len(extension_array_types) == len(scalars_or_arrays) and all( | ||||||
isinstance(x, type(extension_array_types[0])) for x in extension_array_types | ||||||
): | ||||||
return scalars_or_arrays | ||||||
raise ValueError( | ||||||
"Cannot cast arrays to shared type, found" | ||||||
f" array types {[x.dtype for x in scalars_or_arrays]}" | ||||||
) | ||||||
|
||||||
Comment on lines
-255
to
-267
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because we now provide an array-api implementation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for taking this on, it's great work! Can you extract out any non-EA changes to |
||||||
# Avoid calling array_type("cupy") repeatidely in the any check | ||||||
array_type_cupy = array_type("cupy") | ||||||
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): | ||||||
|
@@ -384,7 +383,12 @@ def where(condition, x, y): | |||||
else: | ||||||
condition = astype(condition, dtype=dtype, xp=xp) | ||||||
|
||||||
return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) | ||||||
promoted_x, promoted_y = as_shared_dtype([x, y], xp=xp) | ||||||
|
||||||
# pd.where won't broadcast 0-dim arrays across a series; scalar y's must be preserved | ||||||
maybe_promoted_y = y if is_extension_array_dtype(x) and is_scalar(y) else promoted_y | ||||||
|
||||||
return xp.where(condition, promoted_x, maybe_promoted_y) | ||||||
|
||||||
|
||||||
def where_method(data, cond, other=dtypes.NA): | ||||||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This diff looks ugly, but it simply renames the function and its argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this does not promote to variable-width strings (which would be the new string dtype, `numpy.dtypes.StringDType`) but rather drops the width of an existing fixed-width string dtype to force `numpy` to recalculate the width. The aim is to avoid truncating a Python string object. Additionally, it might be good to keep the "type" in the name of the function, since it only operates on string dtypes.