Skip to content

API: add EA._from_scalars / stricter casting of result values back to EA dtype #38315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def _reconstruct_data(
if isinstance(values, cls) and values.dtype == dtype:
return values

values = cls._from_sequence(values)
values = cls._from_scalars(values, dtype=dtype)
elif is_bool_dtype(dtype):
values = values.astype(dtype, copy=False)

Expand Down
16 changes: 11 additions & 5 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ class ExtensionArray:
# Constructors
# ------------------------------------------------------------------------

@classmethod
def _from_scalars(cls, data, dtype):
if not all(isinstance(v, dtype.type) or isna(v) for v in data):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isna -> is_valid_nat_for_dtype? (still need to rename to is_valid_na_for_dtype)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this has been renamed to the clearer is_valid_na_for_dtype

raise TypeError("Requires dtype scalars")
return cls._from_sequence(data, dtype=dtype)

@classmethod
def _from_sequence(cls, scalars, *, dtype: Optional[Dtype] = None, copy=False):
"""
Expand Down Expand Up @@ -688,7 +694,7 @@ def fillna(self, value=None, method=None, limit=None):
if method is not None:
func = get_fill_func(method)
new_values = func(self.astype(object), limit=limit, mask=mask)
new_values = self._from_sequence(new_values, dtype=self.dtype)
new_values = self._from_scalars(new_values, dtype=self.dtype)
else:
# fill with value
new_values = self.copy()
Expand Down Expand Up @@ -750,7 +756,7 @@ def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
if isna(fill_value):
fill_value = self.dtype.na_value

empty = self._from_sequence(
empty = self._from_scalars(
[fill_value] * min(abs(periods), len(self)), dtype=self.dtype
)
if periods > 0:
Expand All @@ -770,7 +776,7 @@ def unique(self):
uniques : ExtensionArray
"""
uniques = unique(self.astype(object))
return self._from_sequence(uniques, dtype=self.dtype)
return self._from_scalars(uniques, dtype=self.dtype)

def searchsorted(self, value, side="left", sorter=None):
"""
Expand Down Expand Up @@ -1080,7 +1086,7 @@ def take(self, indices, allow_fill=False, fill_value=None):

result = take(data, indices, fill_value=fill_value,
allow_fill=allow_fill)
return self._from_sequence(result, dtype=self.dtype)
return self._from_scalars(result, dtype=self.dtype)
"""
# Implementer note: The `fill_value` parameter should be a user-facing
# value, an instance of self.dtype.type. When passed `fill_value=None`,
Expand Down Expand Up @@ -1420,7 +1426,7 @@ def _maybe_convert(arr):
# https://github.com/pandas-dev/pandas/issues/22850
# We catch all regular exceptions here, and fall back
# to an ndarray.
res = maybe_cast_to_extension_array(type(self), arr)
res = maybe_cast_to_extension_array(type(self), arr, self.dtype)
if not isinstance(res, type(self)):
# exception raised in _from_sequence; ensure we have ndarray
res = np.asarray(arr)
Expand Down
7 changes: 7 additions & 0 deletions pandas/core/arrays/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
def dtype(self) -> BooleanDtype:
return self._dtype

@classmethod
def _from_scalars(cls, data, dtype) -> BooleanArray:
    # Validate before delegating: ``dtype.type`` is only the numpy scalar
    # (np.bool_), so the builtin ``bool`` must be accepted explicitly too.
    for scalar in data:
        if isinstance(scalar, (bool, np.bool_)) or isna(scalar):
            continue
        raise TypeError("Requires dtype scalars")
    return cls._from_sequence(data, dtype=dtype)

@classmethod
def _from_sequence(
cls, scalars, *, dtype: Optional[Dtype] = None, copy: bool = False
Expand Down
7 changes: 7 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,13 @@ def _constructor(self) -> Type[Categorical]:
return Categorical

@classmethod
def _from_scalars(cls, data, dtype):
# if not all(
# isinstance(v, dtype.categories.dtype.type) or isna(v) for v in data
# ):
# raise TypeError("Requires dtype scalars")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not all(x in dtype.categories or is_valid_nat_for_dtype(x, dtype.categories.dtype) for x in data)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, as mentioned somewhere in #38315 (comment), for categorical we probably want to check if the values are valid categories.

We might already have some functionality for this? Like the main constructor, but then raising an error instead of coercing unknown values to NaN:

In [17]: pd.Categorical(["a", "b", "c"], categories=["a", "b"])
Out[17]: 
['a', 'b', NaN]
Categories (2, object): ['a', 'b']

The above is basically done by _get_codes_for_values, so we might want a version of that which is strict instead of coercing to NaN.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Categorical._validate_setitem_value does what you're describing

return cls._from_sequence(data, dtype=dtype)

def _from_sequence(cls, scalars, *, dtype: Optional[Dtype] = None, copy=False):
return Categorical(scalars, dtype=dtype, copy=copy)

Expand Down
7 changes: 7 additions & 0 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,13 @@ def _simple_new(
result._dtype = dtype
return result

@classmethod
def _from_scalars(cls, data, dtype):
    # Validate before delegating: ``dtype.type`` is not always Timestamp,
    # so check against Timestamp (or a missing value) explicitly.
    has_invalid = any(
        not (isinstance(value, Timestamp) or isna(value)) for value in data
    )
    if has_invalid:
        raise TypeError("Requires timestamp scalars")
    return cls._from_sequence(data, dtype=dtype)

@classmethod
def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False):
return cls._from_sequence_not_strict(scalars, dtype=dtype, copy=copy)
Expand Down
7 changes: 7 additions & 0 deletions pandas/core/arrays/floating.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
)
super().__init__(values, mask, copy=copy)

@classmethod
def _from_scalars(cls, data, dtype):
    # Validate before delegating: ``dtype.type`` is only the numpy scalar,
    # so the builtin ``float`` must be accepted alongside it.
    allowed = (float, dtype.type)
    if any(not isinstance(value, allowed) and not isna(value) for value in data):
        raise TypeError("Requires dtype scalars")
    return cls._from_sequence(data, dtype=dtype)

@classmethod
def _from_sequence(
cls, scalars, *, dtype=None, copy: bool = False
Expand Down
10 changes: 10 additions & 0 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ def __pos__(self):
def __abs__(self):
return type(self)(np.abs(self._data), self._mask)

@classmethod
def _from_scalars(cls, data, dtype):
# override because dtype.type is only the numpy scalar
# TODO accept float here?
if not all(
isinstance(v, (int, dtype.type, float, np.float_)) or isna(v) for v in data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for floats require that v.is_integer()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The coerce_to_array function that is used by _from_sequence already checks for this as well (so we can pass any float through here, as it will be caught later)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DTA/TDA/PA have a _recognized_scalars attribute that could be useful

):
raise TypeError("Requires dtype scalars")
return cls._from_sequence(data, dtype=dtype)

@classmethod
def _from_sequence(
cls, scalars, *, dtype: Optional[Dtype] = None, copy: bool = False
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ def shift(self, periods: int = 1, fill_value: object = None) -> IntervalArray:
fill_value = Index(self._left, copy=False)._na_value
empty = IntervalArray.from_breaks([fill_value] * (empty_len + 1))
else:
empty = self._from_sequence([fill_value] * empty_len)
empty = self._from_scalars([fill_value] * empty_len, self.dtype)

if periods > 0:
a = empty
Expand Down
7 changes: 7 additions & 0 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def __init__(self, values: Union[np.ndarray, PandasArray], copy: bool = False):
self._ndarray = values
self._dtype = PandasDtype(values.dtype)

@classmethod
def _from_scalars(cls, data, dtype):
    # Strict scalar validation is deliberately skipped here: for object
    # dtype, an ``isinstance(v, dtype.type)`` check would reject valid
    # values, so this simply delegates to ``_from_sequence``.
    # doesn't work for object dtype
    # if not all(isinstance(v, dtype.type) or isna(v) for v in data):
    #     raise TypeError("Requires dtype scalars")
    return cls._from_sequence(data, dtype=dtype)

@classmethod
def _from_sequence(
cls, scalars, *, dtype: Optional[Dtype] = None, copy: bool = False
Expand Down
13 changes: 10 additions & 3 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,12 @@ def maybe_cast_result(
"""
dtype = obj.dtype
dtype = maybe_cast_result_dtype(dtype, how)
# result_dtype = maybe_cast_result_dtype(dtype, how)
# if result_dtype is not None:
# # we know what the result dtypes needs to be -> be more permissive in casting
# # (eg ints with nans became floats)
# cls = result_dtype.construct_array_type()
# return cls._from_sequence(obj, dtype=result_dtype)
Comment on lines +348 to +353
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related to the first case mentioned (when we know the resulting dtype, we should maybe force the result back, instead of relying on strict casting).
So the above code is one possibility: we can change maybe_cast_result_dtype to only return a dtype if it knows what dtype should be and otherwise return None (instead of passing through the original type). This way, we can take a different path for "known" dtypes, vs when guessing the dtype.


assert not is_scalar(result)

Expand Down Expand Up @@ -395,19 +401,20 @@ def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
):
return Float64Dtype()
return dtype
# return None


def maybe_cast_to_extension_array(
cls: Type[ExtensionArray], obj: ArrayLike, dtype: Optional[ExtensionDtype] = None
) -> ArrayLike:
"""
Call to `_from_sequence` that returns the object unchanged on Exception.
Call to `_from_scalars` that returns the object unchanged on Exception.

Parameters
----------
cls : class, subclass of ExtensionArray
obj : arraylike
Values to pass to cls._from_sequence
Values to pass to cls._from_scalars
dtype : ExtensionDtype, optional

Returns
Expand All @@ -429,7 +436,7 @@ def maybe_cast_to_extension_array(
return obj

try:
result = cls._from_sequence(obj, dtype=dtype)
result = cls._from_scalars(obj, dtype=dtype)
except Exception:
# We can't predict what downstream EA constructors may raise
result = obj
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2951,7 +2951,7 @@ def transpose(self, *args, copy: bool = False) -> DataFrame:
arr_type = dtype.construct_array_type()
values = self.values

new_values = [arr_type._from_sequence(row, dtype=dtype) for row in values]
new_values = [arr_type._from_scalars(row, dtype=dtype) for row in values]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this case we already know we have the correct types, so can't we go directly to _from_sequence?

result = self._constructor(
dict(zip(self.index, new_values)), index=self.columns
)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def fast_xs(self, loc: int) -> ArrayLike:
result[rl] = blk.iget((i, loc))

if isinstance(dtype, ExtensionDtype):
result = dtype.construct_array_type()._from_sequence(result, dtype=dtype)
result = dtype.construct_array_type()._from_scalars(result, dtype=dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too, shouldn't we know we have the correct types?


return result

Expand Down
4 changes: 3 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2913,7 +2913,9 @@ def combine(self, other, func, fill_value=None) -> Series:
# TODO: can we do this for only SparseDtype?
# The function can return something of any type, so check
# if the type is compatible with the calling EA.
new_values = maybe_cast_to_extension_array(type(self._values), new_values)
new_values = maybe_cast_to_extension_array(
type(self._values), new_values, self.dtype
)
return self._constructor(new_values, index=new_index, name=new_name)

def combine_first(self, other) -> Series:
Expand Down
7 changes: 7 additions & 0 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def __init__(self, values, dtype=None, copy=False, context=None):
def dtype(self):
return self._dtype

@classmethod
def _from_scalars(cls, data, dtype):
    # TODO not needed if we keep the base class method
    # Reject anything that is neither a dtype scalar nor a missing value.
    for value in data:
        if isinstance(value, dtype.type) or pd.isna(value):
            continue
        raise TypeError("Requires dtype scalars")
    return cls._from_sequence(data, dtype=dtype)

@classmethod
def _from_sequence(cls, scalars, dtype=None, copy=False):
return cls(scalars)
Expand Down
9 changes: 8 additions & 1 deletion pandas/tests/indexes/interval/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,15 @@ def test_astype_object(self, index):

def test_astype_category(self, index):
result = index.astype("category")
expected = CategoricalIndex(index.values)
# TODO astype doesn't preserve the exact interval dtype (eg uint64)
# while the CategoricalIndex constructor does -> temporarily also
# here convert to object dtype numpy array.
# Once this is fixed, the commented code can be uncommented
# -> https://github.com/pandas-dev/pandas/issues/38316
# expected = CategoricalIndex(index.values)
expected = CategoricalIndex(np.asarray(index.values))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened an issue for this -> #38316 (it's also mentioned in the comment above)

tm.assert_index_equal(result, expected)
# assert result.dtype.categories.dtype == index.dtype

result = index.astype(CategoricalDtype())
tm.assert_index_equal(result, expected)
Expand Down