Skip to content

[ArrowStringArray] PERF: str.partition object fallback #41507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import pandas._libs.lib as lib
from pandas._typing import FrameOrSeriesUnion
from pandas.util._decorators import Appender
from pandas.util._validators import validate_bool_kwarg

from pandas.core.dtypes.common import (
ensure_object,
is_bool_dtype,
is_categorical_dtype,
is_integer,
is_list_like,
is_object_dtype,
is_re,
)
from pandas.core.dtypes.generic import (
Expand Down Expand Up @@ -266,13 +266,7 @@ def _wrap_result(
# infer from ndim if expand is not specified
expand = result.ndim != 1

elif (
expand is True
and is_object_dtype(result)
and not isinstance(self._orig, ABCIndex)
):
# required when expand=True is explicitly specified
# not needed when inferred
elif expand in ("split", "rsplit") and not isinstance(self._orig, ABCIndex):

def cons_row(x):
if is_list_like(x):
Expand All @@ -288,8 +282,10 @@ def cons_row(x):
x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result
]

if not isinstance(expand, bool):
raise ValueError("expand must be True or False")
elif expand in ("partition", "rpartition") and not isinstance(
self._orig, ABCIndex
):
result = list(result)

if expand is False:
# if expand is False, result should have the same name
Expand Down Expand Up @@ -775,14 +771,20 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"):
@Appender(_shared_docs["str_split"] % {"side": "beginning", "method": "split"})
@forbid_nonstring_types(["bytes"])
def split(self, pat=None, n=-1, expand=False):
# Fail fast on a non-bool ``expand`` (None is rejected as well) before any
# splitting work is done.
validate_bool_kwarg(expand, "expand", none_allowed=False)
result = self._data.array._str_split(pat, n, expand)
# NOTE(review): this listing looks like a diff rendering with the +/- markers
# stripped; the single-line return below is presumably the removed line and
# the multi-line return after it the added replacement -- confirm against
# the real file before relying on either.
return self._wrap_result(result, returns_string=expand, expand=expand)
# ``expand`` is forwarded as the tag "split" when truthy (otherwise False), so
# that ``_wrap_result`` can select its ``expand in ("split", "rsplit")`` branch.
return self._wrap_result(
result, returns_string=expand, expand="split" if expand else False
)

@Appender(_shared_docs["str_split"] % {"side": "end", "method": "rsplit"})
@forbid_nonstring_types(["bytes"])
def rsplit(self, pat=None, n=-1, expand=False):
# Validate ``expand`` as a strict bool (None is rejected as well) up front.
validate_bool_kwarg(expand, "expand", none_allowed=False)
# Only ``pat`` and ``n`` reach the array-level call; ``expand`` is not
# forwarded here and is handled by ``_wrap_result`` instead.
result = self._data.array._str_rsplit(pat, n=n)
# NOTE(review): this listing looks like a diff rendering with the +/- markers
# stripped; the single-line return below is presumably the removed line and
# the multi-line return after it the added replacement -- confirm against
# the real file before relying on either.
return self._wrap_result(result, expand=expand, returns_string=expand)
# ``expand`` is forwarded as the tag "rsplit" when truthy (otherwise False), so
# that ``_wrap_result`` can select its ``expand in ("split", "rsplit")`` branch.
return self._wrap_result(
result, expand="rsplit" if expand else False, returns_string=expand
)

_shared_docs[
"str_partition"
Expand Down Expand Up @@ -877,8 +879,11 @@ def rsplit(self, pat=None, n=-1, expand=False):
)
@forbid_nonstring_types(["bytes"])
def partition(self, sep=" ", expand=True):
# Validate ``expand`` as a strict bool (None is rejected as well) up front.
validate_bool_kwarg(expand, "expand", none_allowed=False)
result = self._data.array._str_partition(sep, expand)
# NOTE(review): this listing looks like a diff rendering with the +/- markers
# stripped; the single-line return below is presumably the removed line and
# the multi-line return after it the added replacement -- confirm against
# the real file before relying on either.
return self._wrap_result(result, expand=expand, returns_string=expand)
# ``expand`` is forwarded as the tag "partition" when truthy (otherwise False);
# ``_wrap_result`` checks ``expand in ("partition", "rpartition")`` and, for
# non-Index callers, converts the result with ``list(result)``.
return self._wrap_result(
result, expand="partition" if expand else False, returns_string=expand
)

@Appender(
_shared_docs["str_partition"]
Expand All @@ -891,8 +896,11 @@ def partition(self, sep=" ", expand=True):
)
@forbid_nonstring_types(["bytes"])
def rpartition(self, sep=" ", expand=True):
# Validate ``expand`` as a strict bool (None is rejected as well) up front.
validate_bool_kwarg(expand, "expand", none_allowed=False)
result = self._data.array._str_rpartition(sep, expand)
# NOTE(review): this listing looks like a diff rendering with the +/- markers
# stripped; the single-line return below is presumably the removed line and
# the multi-line return after it the added replacement -- confirm against
# the real file before relying on either.
return self._wrap_result(result, expand=expand, returns_string=expand)
# ``expand`` is forwarded as the tag "rpartition" when truthy (otherwise
# False); ``_wrap_result`` checks ``expand in ("partition", "rpartition")``
# and, for non-Index callers, converts the result with ``list(result)``.
return self._wrap_result(
result, expand="rpartition" if expand else False, returns_string=expand
)

def get(self, i):
"""
Expand Down
14 changes: 11 additions & 3 deletions pandas/tests/strings/test_split_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,12 @@ def test_split_to_dataframe(any_string_dtype):
)
tm.assert_frame_equal(result, exp)

with pytest.raises(ValueError, match="expand must be"):
s.str.split("_", expand="not_a_boolean")

def test_split_expand_kwarg_raises(any_string_dtype):
    """``Series.str.split`` rejects a non-boolean ``expand`` with ValueError."""
    expected_message = 'For argument "expand" expected type bool, received type str'
    # An empty Series is used: the call is still expected to raise with no data.
    empty = Series([], dtype=any_string_dtype)
    with pytest.raises(ValueError, match=expected_message):
        empty.str.split("_", expand="not_a_boolean")


def test_split_to_multiindex_expand():
Expand Down Expand Up @@ -274,7 +278,11 @@ def test_split_to_multiindex_expand():
tm.assert_index_equal(result, exp)
assert result.nlevels == 6

with pytest.raises(ValueError, match="expand must be"):

def test_split_index_expand_kwarg_raises():
    """``Index.str.split`` rejects a non-boolean ``expand`` with ValueError."""
    expected_message = 'For argument "expand" expected type bool, received type str'
    # Mix of strings and missing values (NaN and None) as the Index contents.
    labels = ["some_unequal_splits", "one_of_these_things_is_not", np.nan, None]
    index = Index(labels)
    with pytest.raises(ValueError, match=expected_message):
        index.str.split("_", expand="not_a_boolean")


Expand Down