Skip to content

ENH: add ExtensionArray._explode method; adjust pyarrow extension for use of new interface #54834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/reference/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ objects.

api.extensions.ExtensionArray._accumulate
api.extensions.ExtensionArray._concat_same_type
api.extensions.ExtensionArray._explode
api.extensions.ExtensionArray._formatter
api.extensions.ExtensionArray._from_factorized
api.extensions.ExtensionArray._from_sequence
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
- :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`)
- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
-

Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,10 @@ def _explode(self):
"""
See Series.explode.__doc__.
"""
# This override only knows how to explode list types; for non-list
# types, fall back to the default implementation.
if not pa.types.is_list(self.dtype.pyarrow_dtype):
return super()._explode()
values = self
counts = pa.compute.list_value_length(values._pa_array)
counts = counts.fill_null(1).to_numpy()
Expand Down
36 changes: 36 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class ExtensionArray:
view
_accumulate
_concat_same_type
_explode
_formatter
_from_factorized
_from_sequence
Expand Down Expand Up @@ -1924,6 +1925,41 @@ def _hash_pandas_object(
values, encoding=encoding, hash_key=hash_key, categorize=categorize
)

def _explode(self) -> tuple[Self, npt.NDArray[np.uint64]]:
    """
    Transform each element of list-like to a row.

    This default implementation is meant for arrays whose elements are
    not list-like: there is nothing to expand, so it hands back a copy
    of the array together with a count of one per element, which leaves
    the index unchanged.

    Returns
    -------
    ExtensionArray
        Array with the exploded values.
    np.ndarray[uint64]
        The original lengths of each list-like for determining the
        resulting index.

    See Also
    --------
    Series.explode : The method on the ``Series`` object that this
        extension array method is meant to support.

    Examples
    --------
    >>> import pyarrow as pa
    >>> a = pd.array([[1, 2, 3], [4], [5, 6]],
    ...              dtype=pd.ArrowDtype(pa.list_(pa.int64())))
    >>> a._explode()
    (<ArrowExtensionArray>
    [1, 2, 3, 4, 5, 6]
    Length: 6, dtype: int64[pyarrow], array([3, 1, 2], dtype=int32))
    """
    # Copy first so the caller never receives an alias of ``self``.
    exploded = self.copy()
    # Every element maps to exactly one output row.
    return exploded, np.ones(len(self), dtype=np.uint64)

def tolist(self) -> list:
"""
Return a list of the values.
Expand Down
7 changes: 2 additions & 5 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@
pandas_dtype,
validate_all_hashable,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
ExtensionDtype,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.inference import is_hashable
from pandas.core.dtypes.missing import (
Expand Down Expand Up @@ -4390,7 +4387,7 @@ def explode(self, ignore_index: bool = False) -> Series:
3 4
dtype: object
"""
if isinstance(self.dtype, ArrowDtype) and self.dtype.type == list:
if isinstance(self.dtype, ExtensionDtype):
values, counts = self._values._explode()
elif len(self) and is_object_dtype(self.dtype):
values, counts = reshape.explode(np.asarray(self._values))
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/series/methods/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,13 @@ def test_explode_pyarrow_list_type(ignore_index):
dtype=pd.ArrowDtype(pa.int64()),
)
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("ignore_index", [True, False])
def test_explode_pyarrow_non_list_type(ignore_index):
    # Exploding a pyarrow-backed Series of a non-list type should give back
    # the same values on the default integer index, whatever ignore_index is.
    pa = pytest.importorskip("pyarrow")
    int_ser = pd.Series([1, 2, 3], dtype=pd.ArrowDtype(pa.int64()))
    expected = pd.Series([1, 2, 3], dtype="int64[pyarrow]", index=[0, 1, 2])
    tm.assert_series_equal(int_ser.explode(ignore_index=ignore_index), expected)