Skip to content

Commit 41d937d

Browse files
ENH: support addition with pyarrow string dtypes (#51338)
* ENH: support addition with pyarrow string dtypes
* use pc.binary_join_element_wise
* mypy fixup

---------

Co-authored-by: Matthew Roeschke <[email protected]>
1 parent bab4f30 commit 41d937d

File tree

4 files changed

+33
-28
lines changed

4 files changed

+33
-28
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from copy import deepcopy
4+
import operator
45
import re
56
from typing import (
67
TYPE_CHECKING,
@@ -50,6 +51,7 @@
5051
)
5152
from pandas.core.dtypes.missing import isna
5253

54+
from pandas.core import roperator
5355
from pandas.core.arraylike import OpsMixin
5456
from pandas.core.arrays.base import ExtensionArray
5557
import pandas.core.common as com
@@ -458,6 +460,29 @@ def _cmp_method(self, other, op):
458460
return BooleanArray(values, mask)
459461

460462
def _evaluate_op_method(self, other, op, arrow_funcs):
463+
pa_type = self._data.type
464+
if (pa.types.is_string(pa_type) or pa.types.is_binary(pa_type)) and op in [
465+
operator.add,
466+
roperator.radd,
467+
]:
468+
length = self._data.length()
469+
470+
seps: list[str] | list[bytes]
471+
if pa.types.is_string(pa_type):
472+
seps = [""] * length
473+
else:
474+
seps = [b""] * length
475+
476+
if is_scalar(other):
477+
other = [other] * length
478+
elif isinstance(other, type(self)):
479+
other = other._data
480+
if op is operator.add:
481+
result = pc.binary_join_element_wise(self._data, other, seps)
482+
else:
483+
result = pc.binary_join_element_wise(other, self._data, seps)
484+
return type(self)(result)
485+
461486
pc_func = arrow_funcs[op.__name__]
462487
if pc_func is NotImplemented:
463488
raise NotImplementedError(f"{op.__name__} not implemented.")

pandas/tests/arrays/string_/test_string.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -96,15 +96,7 @@ def test_astype_roundtrip(dtype):
9696
tm.assert_series_equal(result, ser)
9797

9898

99-
def test_add(dtype, request):
100-
if dtype.storage == "pyarrow":
101-
reason = (
102-
"unsupported operand type(s) for +: 'ArrowStringArray' and "
103-
"'ArrowStringArray'"
104-
)
105-
mark = pytest.mark.xfail(raises=NotImplementedError, reason=reason)
106-
request.node.add_marker(mark)
107-
99+
def test_add(dtype):
108100
a = pd.Series(["a", "b", "c", None, None], dtype=dtype)
109101
b = pd.Series(["x", "y", None, "z", None], dtype=dtype)
110102

@@ -140,12 +132,7 @@ def test_add_2d(dtype, request):
140132
s + b
141133

142134

143-
def test_add_sequence(dtype, request):
144-
if dtype.storage == "pyarrow":
145-
reason = "unsupported operand type(s) for +: 'ArrowStringArray' and 'list'"
146-
mark = pytest.mark.xfail(raises=NotImplementedError, reason=reason)
147-
request.node.add_marker(mark)
148-
135+
def test_add_sequence(dtype):
149136
a = pd.array(["a", "b", None, None], dtype=dtype)
150137
other = ["x", None, "y", None]
151138

pandas/tests/extension/test_arrow.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1013,6 +1013,10 @@ def _get_scalar_exception(self, opname, pa_dtype):
10131013
exc = NotImplementedError
10141014
elif arrow_temporal_supported:
10151015
exc = None
1016+
elif opname in ["__add__", "__radd__"] and (
1017+
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
1018+
):
1019+
exc = None
10161020
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
10171021
exc = pa.ArrowNotImplementedError
10181022
else:
@@ -1187,9 +1191,7 @@ def test_add_series_with_extension_array(self, data, request):
11871191
return
11881192

11891193
if (pa_version_under8p0 and pa.types.is_duration(pa_dtype)) or (
1190-
pa.types.is_binary(pa_dtype)
1191-
or pa.types.is_string(pa_dtype)
1192-
or pa.types.is_boolean(pa_dtype)
1194+
pa.types.is_boolean(pa_dtype)
11931195
):
11941196
request.node.add_marker(
11951197
pytest.mark.xfail(

pandas/tests/strings/test_api.py

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
MultiIndex,
77
Series,
88
_testing as tm,
9-
get_option,
109
)
1110
from pandas.core.strings.accessor import StringMethods
1211

@@ -124,16 +123,8 @@ def test_api_per_method(
124123
method(*args, **kwargs)
125124

126125

127-
def test_api_for_categorical(any_string_method, any_string_dtype, request):
126+
def test_api_for_categorical(any_string_method, any_string_dtype):
128127
# https://github.com/pandas-dev/pandas/issues/10661
129-
130-
if any_string_dtype == "string[pyarrow]" or (
131-
any_string_dtype == "string" and get_option("string_storage") == "pyarrow"
132-
):
133-
# unsupported operand type(s) for +: 'ArrowStringArray' and 'str'
134-
mark = pytest.mark.xfail(raises=NotImplementedError, reason="Not Implemented")
135-
request.node.add_marker(mark)
136-
137128
s = Series(list("aabb"), dtype=any_string_dtype)
138129
s = s + " " + s
139130
c = s.astype("category")

0 commit comments

Comments
 (0)