Skip to content

Commit c8a6740

Browse files
String dtype: allow string dtype for non-raw apply with numba engine (#59854)
* String dtype: allow string dtype for non-raw apply with numba engine
* remove xfails
* clean-up
1 parent 7e5282f commit c8a6740

File tree

4 files changed

+2
-11
lines changed

4 files changed

+2
-11
lines changed

pandas/core/_numba/extensions.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,8 @@
5353
@contextmanager
5454
def set_numba_data(index: Index):
5555
numba_data = index._data
56-
if numba_data.dtype == object:
56+
if numba_data.dtype in (object, "string"):
57+
numba_data = np.asarray(numba_data)
5758
if not lib.is_string_array(numba_data):
5859
raise ValueError(
5960
"The numba engine only supports using string or numeric column names"

pandas/core/apply.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1172,12 +1172,7 @@ def apply_with_numba(self) -> dict[int, Any]:
11721172
from pandas.core._numba.extensions import set_numba_data
11731173

11741174
index = self.obj.index
1175-
if index.dtype == "string":
1176-
index = index.astype(object)
1177-
11781175
columns = self.obj.columns
1179-
if columns.dtype == "string":
1180-
columns = columns.astype(object)
11811176

11821177
# Convert from numba dict to regular dict
11831178
# Our isinstance checks in the df constructor don't pass for numbas typed dict

pandas/tests/apply/test_frame_apply.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,6 @@ def test_apply(float_frame, engine, request):
6565
assert result.index is float_frame.index
6666

6767

68-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
6968
@pytest.mark.parametrize("axis", [0, 1])
7069
@pytest.mark.parametrize("raw", [True, False])
7170
@pytest.mark.parametrize("nopython", [True, False])

pandas/tests/apply/test_numba.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
from pandas._config import using_string_dtype
5-
64
import pandas.util._test_decorators as td
75

86
import pandas as pd
@@ -20,7 +18,6 @@ def apply_axis(request):
2018
return request.param
2119

2220

23-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
2421
def test_numba_vs_python_noop(float_frame, apply_axis):
2522
func = lambda x: x
2623
result = float_frame.apply(func, engine="numba", axis=apply_axis)
@@ -43,7 +40,6 @@ def test_numba_vs_python_string_index():
4340
)
4441

4542

46-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
4743
def test_numba_vs_python_indexing():
4844
frame = DataFrame(
4945
{"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]},

0 commit comments

Comments
 (0)