diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 1dd1b12d6ae95..1eb1a630056a2 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -2339,7 +2339,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype: ) if not string.endswith("[pyarrow]"): raise TypeError(f"'{string}' must end with '[pyarrow]'") - if string == "string[pyarrow]": + if string in ("string[pyarrow]", "str[pyarrow]"): # Ensure Registry.find skips ArrowDtype to use StringDtype instead raise TypeError("string[pyarrow] should be constructed by StringDtype") if pa_version_under10p1: diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py index 5a59617ce5bd3..fa48393dd183e 100644 --- a/pandas/tests/dtypes/test_common.py +++ b/pandas/tests/dtypes/test_common.py @@ -837,6 +837,26 @@ def test_pandas_dtype_string_dtypes(string_storage): assert result == pd.StringDtype(string_storage, na_value=pd.NA) +def test_pandas_dtype_string_dtype_alias_with_storage(): + with pytest.raises(TypeError, match="not understood"): + pandas_dtype("str[python]") + + with pytest.raises(TypeError, match="not understood"): + pandas_dtype("str[pyarrow]") + + result = pandas_dtype("string[python]") + assert result == pd.StringDtype("python", na_value=pd.NA) + + if HAS_PYARROW: + result = pandas_dtype("string[pyarrow]") + assert result == pd.StringDtype("pyarrow", na_value=pd.NA) + else: + with pytest.raises( + ImportError, match="required for PyArrow backed StringArray" + ): + pandas_dtype("string[pyarrow]") + + @td.skip_if_installed("pyarrow") def test_construct_from_string_without_pyarrow_installed(): # GH 57928 diff --git a/pandas/tests/strings/test_get_dummies.py b/pandas/tests/strings/test_get_dummies.py index 3b989e284ca25..541b0ea150ba6 100644 --- a/pandas/tests/strings/test_get_dummies.py +++ b/pandas/tests/strings/test_get_dummies.py @@ -6,6 +6,7 @@ import pandas.util._test_decorators as td from pandas import ( + ArrowDtype, DataFrame, Index, MultiIndex, @@ -113,8 +114,10 @@ def test_get_dummies_with_str_dtype(any_string_dtype): # GH#47872 @td.skip_if_no("pyarrow") def test_get_dummies_with_pa_str_dtype(any_string_dtype): + import pyarrow as pa + s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) - result = s.str.get_dummies("|", dtype="str[pyarrow]") + result = s.str.get_dummies("|", dtype=ArrowDtype(pa.string())) expected = DataFrame( [ ["true", "true", "false"], @@ -122,6 +125,6 @@ def test_get_dummies_with_pa_str_dtype(any_string_dtype): ["false", "false", "false"], ], columns=list("abc"), - dtype="str[pyarrow]", + dtype=ArrowDtype(pa.string()), ) tm.assert_frame_equal(result, expected)