Skip to content

Commit e27b0e7

Browse files
authored
PERF: Series.str.get_dummies for ArrowDtype(pa.string()) (#53655)
* PERF: Series.str.get_dummies for ArrowDtype(pa.string()) * whatsnew * typing
1 parent 8893c38 commit e27b0e7

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

doc/source/whatsnew/v2.1.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ Performance improvements
323323
- Performance improvement in :meth:`DataFrame.loc` when selecting rows and columns (:issue:`53014`)
324324
- Performance improvement in :meth:`Series.add` for pyarrow string and binary dtypes (:issue:`53150`)
325325
- Performance improvement in :meth:`Series.corr` and :meth:`Series.cov` for extension dtypes (:issue:`52502`)
326+
- Performance improvement in :meth:`Series.str.get_dummies` for pyarrow-backed strings (:issue:`53655`)
326327
- Performance improvement in :meth:`Series.str.get` for pyarrow-backed strings (:issue:`53152`)
327328
- Performance improvement in :meth:`Series.str.split` with ``expand=True`` for pyarrow-backed strings (:issue:`53585`)
328329
- Performance improvement in :meth:`Series.to_numpy` when dtype is a numpy float dtype and ``na_value`` is ``np.nan`` (:issue:`52430`)

pandas/core/arrays/arrow/array.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,17 +2239,19 @@ def _str_findall(self, pat: str, flags: int = 0):
22392239
return type(self)(pa.chunked_array(result))
22402240

22412241
def _str_get_dummies(self, sep: str = "|"):
2242-
split = pc.split_pattern(self._pa_array, sep).combine_chunks()
2243-
uniques = split.flatten().unique()
2242+
split = pc.split_pattern(self._pa_array, sep)
2243+
flattened_values = pc.list_flatten(split)
2244+
uniques = flattened_values.unique()
22442245
uniques_sorted = uniques.take(pa.compute.array_sort_indices(uniques))
2245-
result_data = []
2246-
for lst in split.to_pylist():
2247-
if lst is None:
2248-
result_data.append([False] * len(uniques_sorted))
2249-
else:
2250-
res = pc.is_in(uniques_sorted, pa.array(set(lst)))
2251-
result_data.append(res.to_pylist())
2252-
result = type(self)(pa.array(result_data))
2246+
lengths = pc.list_value_length(split).fill_null(0).to_numpy()
2247+
n_rows = len(self)
2248+
n_cols = len(uniques)
2249+
indices = pc.index_in(flattened_values, uniques_sorted).to_numpy()
2250+
indices = indices + np.arange(n_rows).repeat(lengths) * n_cols
2251+
dummies = np.zeros(n_rows * n_cols, dtype=np.bool_)
2252+
dummies[indices] = True
2253+
dummies = dummies.reshape((n_rows, n_cols))
2254+
result = type(self)(pa.array(list(dummies)))
22532255
return result, uniques_sorted.to_pylist()
22542256

22552257
def _str_index(self, sub: str, start: int = 0, end: int | None = None):

0 commit comments

Comments
 (0)