-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
API: add EA._from_scalars / stricter casting of result values back to EA dtype #38315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c8dc332
2bf2992
4ddfe1c
98eec57
5c4dc77
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -426,6 +426,13 @@ def _constructor(self) -> Type[Categorical]: | |
return Categorical | ||
|
||
@classmethod | ||
def _from_scalars(cls, data, dtype): | ||
# if not all( | ||
# isinstance(v, dtype.categories.dtype.type) or isna(v) for v in data | ||
# ): | ||
# raise TypeError("Requires dtype scalars") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, as mentioned somewhere in #38315 (comment), for categorical we probably want to check if the values are valid categories. We might already have some functionality for this? Like the main constructor, but then raising an error instead of coercing unknown values to NaN:
The above is basically done by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think Categorical._validate_setitem_value does what you're describing |
||
return cls._from_sequence(data, dtype=dtype) | ||
|
||
def _from_sequence(cls, scalars, *, dtype: Optional[Dtype] = None, copy=False): | ||
return Categorical(scalars, dtype=dtype, copy=copy) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -302,6 +302,16 @@ def __pos__(self): | |
def __abs__(self): | ||
return type(self)(np.abs(self._data), self._mask) | ||
|
||
@classmethod | ||
def _from_scalars(cls, data, dtype): | ||
# override because dtype.type is only the numpy scalar | ||
# TODO accept float here? | ||
if not all( | ||
isinstance(v, (int, dtype.type, float, np.float_)) or isna(v) for v in data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for floats require that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DTA/TDA/PA have a |
||
): | ||
raise TypeError("Requires dtype scalars") | ||
return cls._from_sequence(data, dtype=dtype) | ||
|
||
@classmethod | ||
def _from_sequence( | ||
cls, scalars, *, dtype: Optional[Dtype] = None, copy: bool = False | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -345,6 +345,12 @@ def maybe_cast_result( | |
""" | ||
dtype = obj.dtype | ||
dtype = maybe_cast_result_dtype(dtype, how) | ||
# result_dtype = maybe_cast_result_dtype(dtype, how) | ||
# if result_dtype is not None: | ||
# # we know what the result dtype needs to be -> be more permissive in casting | ||
# # (eg ints with nans became floats) | ||
# cls = result_dtype.construct_array_type() | ||
# return cls._from_sequence(obj, dtype=result_dtype) | ||
Comment on lines
+348
to
+353
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is related to the first case mentioned (when we know the resulting dtype, we should maybe force the result back, instead of relying on strict casting). |
||
|
||
assert not is_scalar(result) | ||
|
||
|
@@ -395,19 +401,20 @@ def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj: | |
): | ||
return Float64Dtype() | ||
return dtype | ||
# return None | ||
|
||
|
||
def maybe_cast_to_extension_array( | ||
cls: Type[ExtensionArray], obj: ArrayLike, dtype: Optional[ExtensionDtype] = None | ||
) -> ArrayLike: | ||
""" | ||
Call to `_from_sequence` that returns the object unchanged on Exception. | ||
Call to `_from_scalars` that returns the object unchanged on Exception. | ||
|
||
Parameters | ||
---------- | ||
cls : class, subclass of ExtensionArray | ||
obj : arraylike | ||
Values to pass to cls._from_sequence | ||
Values to pass to cls._from_scalars | ||
dtype : ExtensionDtype, optional | ||
|
||
Returns | ||
|
@@ -429,7 +436,7 @@ def maybe_cast_to_extension_array( | |
return obj | ||
|
||
try: | ||
result = cls._from_sequence(obj, dtype=dtype) | ||
result = cls._from_scalars(obj, dtype=dtype) | ||
except Exception: | ||
# We can't predict what downstream EA constructors may raise | ||
result = obj | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2951,7 +2951,7 @@ def transpose(self, *args, copy: bool = False) -> DataFrame: | |
arr_type = dtype.construct_array_type() | ||
values = self.values | ||
|
||
new_values = [arr_type._from_sequence(row, dtype=dtype) for row in values] | ||
new_values = [arr_type._from_scalars(row, dtype=dtype) for row in values] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in this case we already know we have the correct types, so can't we go directly to _from_sequence? |
||
result = self._constructor( | ||
dict(zip(self.index, new_values)), index=self.columns | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -957,7 +957,7 @@ def fast_xs(self, loc: int) -> ArrayLike: | |
result[rl] = blk.iget((i, loc)) | ||
|
||
if isinstance(dtype, ExtensionDtype): | ||
result = dtype.construct_array_type()._from_sequence(result, dtype=dtype) | ||
result = dtype.construct_array_type()._from_scalars(result, dtype=dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here too, shouldn't we know we have the correct types? |
||
|
||
return result | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,8 +35,15 @@ def test_astype_object(self, index): | |
|
||
def test_astype_category(self, index): | ||
result = index.astype("category") | ||
expected = CategoricalIndex(index.values) | ||
# TODO astype doesn't preserve the exact interval dtype (eg uint64) | ||
# while the CategoricalIndex constructor does -> temporarily also | ||
# here convert to object dtype numpy array. | ||
# Once this is fixed, the commented code can be uncommented | ||
# -> https://github.com/pandas-dev/pandas/issues/38316 | ||
# expected = CategoricalIndex(index.values) | ||
expected = CategoricalIndex(np.asarray(index.values)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opened an issue for this -> #38316 (it's also mentioned in the comment above) |
||
tm.assert_index_equal(result, expected) | ||
# assert result.dtype.categories.dtype == index.dtype | ||
|
||
result = index.astype(CategoricalDtype()) | ||
tm.assert_index_equal(result, expected) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isna -> is_valid_nat_for_dtype? (still need to rename to is_valid_na_for_dtype)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this has been renamed to the clearer is_valid_na_for_dtype