diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index ac74e6a8e5f77..65bd941800294 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -50,10 +50,20 @@ Notable bug fixes These are bug fixes that might have notable behavior changes. -.. _whatsnew_230.notable_bug_fixes.notable_bug_fix1: +.. _whatsnew_230.notable_bug_fixes.string_comparisons: -notable_bug_fix1 -^^^^^^^^^^^^^^^^ +Comparisons between different string dtypes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In previous versions, comparing Series of different string dtypes (e.g. ``pd.StringDtype("pyarrow", na_value=pd.NA)`` against ``pd.StringDtype("python", na_value=np.nan)``) would result in inconsistent resulting dtype or incorrectly raise. pandas will now use the hierarchy + + object < (python, NaN) < (pyarrow, NaN) < (python, NA) < (pyarrow, NA) + +in determining the result dtype when there are different string dtypes compared. Some examples: + +- When ``pd.StringDtype("pyarrow", na_value=pd.NA)`` is compared against any other string dtype, the result will always be ``boolean[pyarrow]``. +- When ``pd.StringDtype("python", na_value=pd.NA)`` is compared against ``pd.StringDtype("pyarrow", na_value=np.nan)``, the result will be ``boolean``, the NumPy-backed nullable extension array. +- When ``pd.StringDtype("python", na_value=pd.NA)`` is compared against ``pd.StringDtype("python", na_value=np.nan)``, the result will be ``boolean``, the NumPy-backed nullable extension array. .. _whatsnew_230.api_changes: diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index d7187b57a69e4..0b90bcea35100 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -33,7 +33,6 @@ infer_dtype_from_scalar, ) from pandas.core.dtypes.common import ( - CategoricalDtype, is_array_like, is_bool_dtype, is_float_dtype, @@ -730,9 +729,7 @@ def __setstate__(self, state) -> None: def _cmp_method(self, other, op) -> ArrowExtensionArray: pc_func = ARROW_CMP_FUNCS[op.__name__] - if isinstance( - other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray) - ) or isinstance(getattr(other, "dtype", None), CategoricalDtype): + if isinstance(other, (ExtensionArray, np.ndarray, list)): try: result = pc_func(self._pa_array, self._box_pa(other)) except pa.ArrowNotImplementedError: diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index ac758d0ef093c..8048306df91a2 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -1015,7 +1015,30 @@ def searchsorted( return super().searchsorted(value=value, side=side, sorter=sorter) def _cmp_method(self, other, op): - from pandas.arrays import BooleanArray + from pandas.arrays import ( + ArrowExtensionArray, + BooleanArray, + ) + + if ( + isinstance(other, BaseStringArray) + and self.dtype.na_value is not libmissing.NA + and other.dtype.na_value is libmissing.NA + ): + # NA has priority of NaN semantics + return NotImplemented + + if isinstance(other, ArrowExtensionArray): + if isinstance(other, BaseStringArray): + # pyarrow storage has priority over python storage + # (except if we have NA semantics and other not) + if not ( + self.dtype.na_value is libmissing.NA + and other.dtype.na_value is not libmissing.NA + ): + return NotImplemented + else: + return NotImplemented if isinstance(other, StringArray): other = other._ndarray diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index a39d64429d162..9668981df827b 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -473,6 +473,14 @@ def value_counts(self, dropna: bool = True) -> Series: return result def _cmp_method(self, other, op): + if ( + isinstance(other, (BaseStringArray, ArrowExtensionArray)) + and self.dtype.na_value is not libmissing.NA + and other.dtype.na_value is libmissing.NA + ): + # NA has priority of NaN semantics + return NotImplemented + result = super()._cmp_method(other, op) if self.dtype.na_value is np.nan: if op == operator.ne: diff --git a/pandas/core/ops/invalid.py b/pandas/core/ops/invalid.py index 395db1617cb63..62aa79a881717 100644 --- a/pandas/core/ops/invalid.py +++ b/pandas/core/ops/invalid.py @@ -25,7 +25,7 @@ def invalid_comparison( left: ArrayLike, - right: ArrayLike | Scalar, + right: ArrayLike | list | Scalar, op: Callable[[Any, Any], bool], ) -> npt.NDArray[np.bool_]: """ diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 5670fad7e2f4f..736c0e1782fc0 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -10,10 +10,12 @@ from pandas._config import using_string_dtype +from pandas.compat import HAS_PYARROW from pandas.compat.pyarrow import ( pa_version_under12p0, pa_version_under19p0, ) +import pandas.util._test_decorators as td from pandas.core.dtypes.common import is_dtype_equal @@ -45,6 +47,25 @@ def cls(dtype): return dtype.construct_array_type() +def string_dtype_highest_priority(dtype1, dtype2): + if HAS_PYARROW: + DTYPE_HIERARCHY = [ + pd.StringDtype("python", na_value=np.nan), + pd.StringDtype("pyarrow", na_value=np.nan), + pd.StringDtype("python", na_value=pd.NA), + pd.StringDtype("pyarrow", na_value=pd.NA), + ] + else: + DTYPE_HIERARCHY = [ + pd.StringDtype("python", na_value=np.nan), + pd.StringDtype("python", na_value=pd.NA), + ] + + h1 = DTYPE_HIERARCHY.index(dtype1) + h2 = DTYPE_HIERARCHY.index(dtype2) + return DTYPE_HIERARCHY[max(h1, h2)] + + def test_dtype_constructor(): pytest.importorskip("pyarrow") @@ -331,13 +352,18 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype): tm.assert_extension_array_equal(result, expected) -def test_comparison_methods_array(comparison_op, dtype): +def test_comparison_methods_array(comparison_op, dtype, dtype2): op_name = f"__{comparison_op.__name__}__" a = pd.array(["a", None, "c"], dtype=dtype) - other = [None, None, "c"] - result = getattr(a, op_name)(other) - if dtype.na_value is np.nan: + other = pd.array([None, None, "c"], dtype=dtype2) + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + if dtype.na_value is np.nan and dtype2.na_value is np.nan: if operator.ne == comparison_op: expected = np.array([True, True, False]) else: @@ -345,11 +371,56 @@ def test_comparison_methods_array(comparison_op, dtype): expected[-1] = getattr(other[-1], op_name)(a[-1]) tm.assert_numpy_array_equal(result, expected) - result = getattr(a, op_name)(pd.NA) + else: + max_dtype = string_dtype_highest_priority(dtype, dtype2) + if max_dtype.storage == "python": + expected_dtype = "boolean" + else: + expected_dtype = "bool[pyarrow]" + + expected = np.full(len(a), fill_value=None, dtype="object") + expected[-1] = getattr(other[-1], op_name)(a[-1]) + expected = pd.array(expected, dtype=expected_dtype) + tm.assert_extension_array_equal(result, expected) + + +@td.skip_if_no("pyarrow") +def test_comparison_methods_array_arrow_extension(comparison_op, dtype2): + # Test pd.ArrowDtype(pa.string()) against other string arrays + import pyarrow as pa + + op_name = f"__{comparison_op.__name__}__" + dtype = pd.ArrowDtype(pa.string()) + a = pd.array(["a", None, "c"], dtype=dtype) + other = pd.array([None, None, "c"], dtype=dtype2) + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + expected = pd.array([None, None, True], dtype="bool[pyarrow]") + expected[-1] = getattr(other[-1], op_name)(a[-1]) + tm.assert_extension_array_equal(result, expected) + + +def test_comparison_methods_list(comparison_op, dtype): + op_name = f"__{comparison_op.__name__}__" + + a = pd.array(["a", None, "c"], dtype=dtype) + other = [None, None, "c"] + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + if dtype.na_value is np.nan: if operator.ne == comparison_op: - expected = np.array([True, True, True]) + expected = np.array([True, True, False]) else: expected = np.array([False, False, False]) + expected[-1] = getattr(other[-1], op_name)(a[-1]) tm.assert_numpy_array_equal(result, expected) else: @@ -359,10 +430,6 @@ def test_comparison_methods_array(comparison_op, dtype): expected = pd.array(expected, dtype=expected_dtype) tm.assert_extension_array_equal(result, expected) - result = getattr(a, op_name)(pd.NA) - expected = pd.array([None, None, None], dtype=expected_dtype) - tm.assert_extension_array_equal(result, expected) - def test_constructor_raises(cls): if cls is pd.arrays.StringArray: diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 25129111180d6..96c014f549056 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -31,6 +31,7 @@ from pandas.api.types import is_string_dtype from pandas.core.arrays import ArrowStringArray from pandas.core.arrays.string_ import StringDtype +from pandas.tests.arrays.string_.test_string import string_dtype_highest_priority from pandas.tests.extension import base @@ -202,10 +203,13 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): dtype = cast(StringDtype, tm.get_dtype(obj)) if op_name in ["__add__", "__radd__"]: cast_to = dtype + dtype_other = tm.get_dtype(other) if not isinstance(other, str) else None + if isinstance(dtype_other, StringDtype): + cast_to = string_dtype_highest_priority(dtype, dtype_other) elif dtype.na_value is np.nan: cast_to = np.bool_ # type: ignore[assignment] elif dtype.storage == "pyarrow": - cast_to = "boolean[pyarrow]" # type: ignore[assignment] + cast_to = "bool[pyarrow]" # type: ignore[assignment] else: cast_to = "boolean" # type: ignore[assignment] return pointwise_result.astype(cast_to) @@ -236,10 +240,10 @@ def test_arith_series_with_array( if ( using_infer_string and all_arithmetic_operators == "__radd__" - and ( - (dtype.na_value is pd.NA) or (dtype.storage == "python" and HAS_PYARROW) - ) + and dtype.na_value is pd.NA + and (HAS_PYARROW or dtype.storage == "pyarrow") ): + # TODO(infer_string) mark = pytest.mark.xfail( reason="The pointwise operation result will be inferred to " "string[nan, pyarrow], which does not match the input dtype"