From 289cd208c23e7333f241af3b9ac68461554896bf Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Sat, 21 Sep 2024 00:35:48 +0200 Subject: [PATCH 1/3] String dtype: allow string dtype for non-raw apply with numba engine --- pandas/core/_numba/extensions.py | 3 ++- pandas/core/apply.py | 8 ++++---- pandas/tests/apply/test_frame_apply.py | 6 ------ 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/pandas/core/_numba/extensions.py b/pandas/core/_numba/extensions.py index e6f0427de2a3a..413fdafc7fd04 100644 --- a/pandas/core/_numba/extensions.py +++ b/pandas/core/_numba/extensions.py @@ -53,7 +53,8 @@ @contextmanager def set_numba_data(index: Index): numba_data = index._data - if numba_data.dtype == object: + if numba_data.dtype in (object, "string"): + numba_data = np.asarray(numba_data) if not lib.is_string_array(numba_data): raise ValueError( "The numba engine only supports using string or numeric column names" diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 5959156d11123..5447ae87fa864 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -1172,12 +1172,12 @@ def apply_with_numba(self) -> dict[int, Any]: from pandas.core._numba.extensions import set_numba_data index = self.obj.index - if index.dtype == "string": - index = index.astype(object) + # if index.dtype == "string": + # index = index.astype(object) columns = self.obj.columns - if columns.dtype == "string": - columns = columns.astype(object) + # if columns.dtype == "string": + # columns = columns.astype(object) # Convert from numba dict to regular dict # Our isinstance checks in the df constructor don't pass for numbas typed dict diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 3be3562d23cd6..27f55de6b86ac 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -6,8 +6,6 @@ from pandas._config import using_string_dtype -from pandas.compat import HAS_PYARROW - from pandas.core.dtypes.dtypes import CategoricalDtype import pandas as pd @@ -65,7 +63,6 @@ def test_apply(float_frame, engine, request): assert result.index is float_frame.index -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False) @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("raw", [True, False]) @pytest.mark.parametrize("nopython", [True, False]) @@ -1247,9 +1244,6 @@ def test_agg_multiple_mixed(): tm.assert_frame_equal(result, expected) -@pytest.mark.xfail( - using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)" -) def test_agg_multiple_mixed_raises(): # GH 20909 mdf = DataFrame( From 676c28f696a44dfc4c09a524f7c9feca96c4eb38 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Sat, 21 Sep 2024 09:38:14 +0200 Subject: [PATCH 2/3] remove xfails --- pandas/tests/apply/test_numba.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pandas/tests/apply/test_numba.py b/pandas/tests/apply/test_numba.py index 825d295043e69..d6cd9c321ace6 100644 --- a/pandas/tests/apply/test_numba.py +++ b/pandas/tests/apply/test_numba.py @@ -1,8 +1,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - import pandas.util._test_decorators as td import pandas as pd @@ -20,7 +18,6 @@ def apply_axis(request): return request.param -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_numba_vs_python_noop(float_frame, apply_axis): func = lambda x: x result = float_frame.apply(func, engine="numba", axis=apply_axis) @@ -43,7 +40,6 @@ def test_numba_vs_python_string_index(): ) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_numba_vs_python_indexing(): frame = DataFrame( {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]}, From 1d99b348f73ae44e7e03676106d8dd9e49ae3f21 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Sat, 21 Sep 2024 09:41:27 +0200 Subject: [PATCH 3/3] clean-up --- pandas/core/apply.py | 5 ----- pandas/tests/apply/test_frame_apply.py | 5 +++++ 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 5447ae87fa864..7d50b466f5126 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -1172,12 +1172,7 @@ def apply_with_numba(self) -> dict[int, Any]: from pandas.core._numba.extensions import set_numba_data index = self.obj.index - # if index.dtype == "string": - # index = index.astype(object) - columns = self.obj.columns - # if columns.dtype == "string": - # columns = columns.astype(object) # Convert from numba dict to regular dict # Our isinstance checks in the df constructor don't pass for numbas typed dict diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 27f55de6b86ac..dee0efcd8fd15 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -6,6 +6,8 @@ from pandas._config import using_string_dtype +from pandas.compat import HAS_PYARROW + from pandas.core.dtypes.dtypes import CategoricalDtype import pandas as pd @@ -1244,6 +1246,9 @@ def test_agg_multiple_mixed(): tm.assert_frame_equal(result, expected) +@pytest.mark.xfail( + using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)" +) def test_agg_multiple_mixed_raises(): # GH 20909 mdf = DataFrame(