From 45d79b56c72268e5ff1a3cf4c54571b18707e136 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 19 Aug 2024 14:15:27 -0700 Subject: [PATCH 1/4] REF: de-duplicate arrow string methods --- pandas/core/arrays/string_arrow.py | 46 ++++++------------------------ 1 file changed, 8 insertions(+), 38 deletions(-) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 67114815341b6..163b33405cf39 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -358,7 +358,7 @@ def _str_repeat(self, repeats: int | Sequence[int]): if not isinstance(repeats, int): return super()._str_repeat(repeats) else: - return type(self)(pc.binary_repeat(self._pa_array, repeats)) + return ArrowExtensionArray._str_repeat(self, repeats=repeats) def _str_match( self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None @@ -379,13 +379,7 @@ def _str_slice( ) -> Self: if stop is None: return super()._str_slice(start, stop, step) - if start is None: - start = 0 - if step is None: - step = 1 - return type(self)( - pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step) - ) + return ArrowExtensionArray._str_slice(self, start=start, stop=stop, step=step) def _str_isalnum(self): result = pc.utf8_is_alnum(self._pa_array) @@ -427,39 +421,15 @@ def _str_len(self): result = pc.utf8_length(self._pa_array) return self._convert_int_dtype(result) - def _str_lower(self) -> Self: - return type(self)(pc.utf8_lower(self._pa_array)) - - def _str_upper(self) -> Self: - return type(self)(pc.utf8_upper(self._pa_array)) - - def _str_strip(self, to_strip=None) -> Self: - if to_strip is None: - result = pc.utf8_trim_whitespace(self._pa_array) - else: - result = pc.utf8_trim(self._pa_array, characters=to_strip) - return type(self)(result) - - def _str_lstrip(self, to_strip=None) -> Self: - if to_strip is None: - result = pc.utf8_ltrim_whitespace(self._pa_array) - else: - result = pc.utf8_ltrim(self._pa_array, characters=to_strip) - return type(self)(result) - - def _str_rstrip(self, to_strip=None) -> Self: - if to_strip is None: - result = pc.utf8_rtrim_whitespace(self._pa_array) - else: - result = pc.utf8_rtrim(self._pa_array, characters=to_strip) - return type(self)(result) + _str_lower = ArrowExtensionArray._str_lower + _str_upper = ArrowExtensionArray._str_upper + _str_strip = ArrowExtensionArray._str_strip + _str_lstrip = ArrowExtensionArray._str_lstrip + _str_rstrip = ArrowExtensionArray._str_rstrip def _str_removeprefix(self, prefix: str): if not pa_version_under13p0: - starts_with = pc.starts_with(self._pa_array, pattern=prefix) - removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix)) - result = pc.if_else(starts_with, removed, self._pa_array) - return type(self)(result) + return ArrowExtensionArray._str_removeprefix(self, prefix) return super()._str_removeprefix(prefix) def _str_removesuffix(self, suffix: str): From 8fbce0bc058ecc12af49df5484eebbf7df673ef0 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 19 Aug 2024 14:37:41 -0700 Subject: [PATCH 2/4] REF: de-duplicate ArrowStringArray methods --- pandas/core/arrays/string_arrow.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 163b33405cf39..045140d18ac03 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -360,20 +360,6 @@ def _str_repeat(self, repeats: int | Sequence[int]): else: return ArrowExtensionArray._str_repeat(self, repeats=repeats) - def _str_match( - self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None - ): - if not pat.startswith("^"): - pat = f"^{pat}" - return self._str_contains(pat, case, flags, na, regex=True) - - def _str_fullmatch( - self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None - ): - if not pat.endswith("$") or pat.endswith("\\$"): - pat = f"{pat}$" - return self._str_match(pat, case, flags, na) - def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None ) -> Self: @@ -421,23 +407,20 @@ def _str_len(self): result = pc.utf8_length(self._pa_array) return self._convert_int_dtype(result) + _str_match = ArrowExtensionArray._str_match + _str_fullmatch = ArrowExtensionArray._str_fullmatch _str_lower = ArrowExtensionArray._str_lower _str_upper = ArrowExtensionArray._str_upper _str_strip = ArrowExtensionArray._str_strip _str_lstrip = ArrowExtensionArray._str_lstrip _str_rstrip = ArrowExtensionArray._str_rstrip + _str_removesuffix = ArrowStringArrayMixin._str_removesuffix def _str_removeprefix(self, prefix: str): if not pa_version_under13p0: return ArrowExtensionArray._str_removeprefix(self, prefix) return super()._str_removeprefix(prefix) - def _str_removesuffix(self, suffix: str): - ends_with = pc.ends_with(self._pa_array, pattern=suffix) - removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix)) - result = pc.if_else(ends_with, removed, self._pa_array) - return type(self)(result) - def _str_count(self, pat: str, flags: int = 0): if flags: return super()._str_count(pat, flags) From 6f7bec0fc2f73fa33de2802a73a6facc4b20db94 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 19 Aug 2024 14:53:29 -0700 Subject: [PATCH 3/4] REF: de-duplicate ArrowStringArray methods (2) --- pandas/core/arrays/_arrow_string_mixins.py | 48 ++++++++++++++ pandas/core/arrays/arrow/array.py | 38 +++-------- pandas/core/arrays/string_arrow.py | 74 ++++------------------ 3 files changed, 69 insertions(+), 91 deletions(-) diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index 06c74290bd82e..d8daee6771771 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -23,6 +23,54 @@ class ArrowStringArrayMixin: def __init__(self, *args, **kwargs) -> None: raise NotImplementedError + def _result_converter(self, result: pa.Array, na=None): + # Convert bool-dtype results to the appropriate output type + raise NotImplementedError + + def _str_isalnum(self) -> Self: + result = pc.utf8_is_alnum(self._pa_array) + return self._result_converter(result) + + def _str_isalpha(self): + result = pc.utf8_is_alpha(self._pa_array) + return self._result_converter(result) + + def _str_isdecimal(self): + result = pc.utf8_is_decimal(self._pa_array) + return self._result_converter(result) + + def _str_isdigit(self): + result = pc.utf8_is_digit(self._pa_array) + return self._result_converter(result) + + def _str_islower(self): + result = pc.utf8_is_lower(self._pa_array) + return self._result_converter(result) + + def _str_isnumeric(self): + result = pc.utf8_is_numeric(self._pa_array) + return self._result_converter(result) + + def _str_isspace(self): + result = pc.utf8_is_space(self._pa_array) + return self._result_converter(result) + + def _str_istitle(self): + result = pc.utf8_is_title(self._pa_array) + return self._result_converter(result) + + def _str_isupper(self): + result = pc.utf8_is_upper(self._pa_array) + return self._result_converter(result) + + def _convert_int_dtype(self, result): + # Convert int-dtype results to the appropriate output type + raise NotImplementedError + + def _str_len(self): + result = pc.utf8_length(self._pa_array) + return self._convert_int_dtype(result) + def _str_pad( self, width: int, diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index e95fa441e18fb..482f4ff9a7ca9 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1972,7 +1972,7 @@ def _rank( """ See Series.rank.__doc__. """ - return type(self)( + return self._convert_int_dtype( self._rank_calc( axis=axis, method=method, @@ -2288,7 +2288,14 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]: def _str_count(self, pat: str, flags: int = 0) -> Self: if flags: raise NotImplementedError(f"count not implemented with {flags=}") - return type(self)(pc.count_substring_regex(self._pa_array, pat)) + result = pc.count_substring_regex(self._pa_array, pat) + return self._convert_int_dtype(result) + + def _result_converter(self, result, na=None): + return type(self)(result) + + def _convert_int_dtype(self, result): + return type(self)(result) def _str_contains( self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True @@ -2441,33 +2448,6 @@ def _str_slice( pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step) ) - def _str_isalnum(self) -> Self: - return type(self)(pc.utf8_is_alnum(self._pa_array)) - - def _str_isalpha(self) -> Self: - return type(self)(pc.utf8_is_alpha(self._pa_array)) - - def _str_isdecimal(self) -> Self: - return type(self)(pc.utf8_is_decimal(self._pa_array)) - - def _str_isdigit(self) -> Self: - return type(self)(pc.utf8_is_digit(self._pa_array)) - - def _str_islower(self) -> Self: - return type(self)(pc.utf8_is_lower(self._pa_array)) - - def _str_isnumeric(self) -> Self: - return type(self)(pc.utf8_is_numeric(self._pa_array)) - - def _str_isspace(self) -> Self: - return type(self)(pc.utf8_is_space(self._pa_array)) - - def _str_istitle(self) -> Self: - return type(self)(pc.utf8_is_title(self._pa_array)) - - def _str_isupper(self) -> Self: - return type(self)(pc.utf8_is_upper(self._pa_array)) - def _str_len(self) -> Self: return type(self)(pc.utf8_length(self._pa_array)) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 045140d18ac03..9d813674edb00 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -52,7 +52,6 @@ from pandas._typing import ( ArrayLike, - AxisInt, Dtype, Scalar, Self, @@ -367,45 +366,17 @@ def _str_slice( return super()._str_slice(start, stop, step) return ArrowExtensionArray._str_slice(self, start=start, stop=stop, step=step) - def _str_isalnum(self): - result = pc.utf8_is_alnum(self._pa_array) - return self._result_converter(result) - - def _str_isalpha(self): - result = pc.utf8_is_alpha(self._pa_array) - return self._result_converter(result) + _str_isalnum = ArrowStringArrayMixin._str_isalnum + _str_isalpha = ArrowStringArrayMixin._str_isalpha + _str_isdecimal = ArrowStringArrayMixin._str_isdecimal + _str_isdigit = ArrowStringArrayMixin._str_isdigit + _str_islower = ArrowStringArrayMixin._str_islower + _str_isnumeric = ArrowStringArrayMixin._str_isnumeric + _str_isspace = ArrowStringArrayMixin._str_isspace + _str_istitle = ArrowStringArrayMixin._str_istitle + _str_isupper = ArrowStringArrayMixin._str_isupper - def _str_isdecimal(self): - result = pc.utf8_is_decimal(self._pa_array) - return self._result_converter(result) - - def _str_isdigit(self): - result = pc.utf8_is_digit(self._pa_array) - return self._result_converter(result) - - def _str_islower(self): - result = pc.utf8_is_lower(self._pa_array) - return self._result_converter(result) - - def _str_isnumeric(self): - result = pc.utf8_is_numeric(self._pa_array) - return self._result_converter(result) - - def _str_isspace(self): - result = pc.utf8_is_space(self._pa_array) - return self._result_converter(result) - - def _str_istitle(self): - result = pc.utf8_is_title(self._pa_array) - return self._result_converter(result) - - def _str_isupper(self): - result = pc.utf8_is_upper(self._pa_array) - return self._result_converter(result) - - def _str_len(self): - result = pc.utf8_length(self._pa_array) - return self._convert_int_dtype(result) + _str_len = ArrowStringArrayMixin._str_len _str_match = ArrowExtensionArray._str_match _str_fullmatch = ArrowExtensionArray._str_fullmatch @@ -424,8 +395,7 @@ def _str_removeprefix(self, prefix: str): def _str_count(self, pat: str, flags: int = 0): if flags: return super()._str_count(pat, flags) - result = pc.count_substring_regex(self._pa_array, pat) - return self._convert_int_dtype(result) + return ArrowExtensionArray._str_count(self, pat, flags) def _str_find(self, sub: str, start: int = 0, end: int | None = None): if start != 0 and end is not None: @@ -481,27 +451,7 @@ def _reduce( else: return result - def _rank( - self, - *, - axis: AxisInt = 0, - method: str = "average", - na_option: str = "keep", - ascending: bool = True, - pct: bool = False, - ): - """ - See Series.rank.__doc__. - """ - return self._convert_int_dtype( - self._rank_calc( - axis=axis, - method=method, - na_option=na_option, - ascending=ascending, - pct=pct, - ) - ) + _rank = ArrowExtensionArray._rank def value_counts(self, dropna: bool = True) -> Series: result = super().value_counts(dropna=dropna) From bd4a83b353f52e31df632b4439317fea488baa1b Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 19 Aug 2024 15:03:50 -0700 Subject: [PATCH 4/4] REF: remove redundant _str_len --- pandas/core/arrays/arrow/array.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 482f4ff9a7ca9..5d7ff3c73a44f 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2448,9 +2448,6 @@ def _str_slice( pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step) ) - def _str_len(self) -> Self: - return type(self)(pc.utf8_length(self._pa_array)) - def _str_lower(self) -> Self: return type(self)(pc.utf8_lower(self._pa_array))