Skip to content

CLN: cleanup strings._wrap_result #12487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 39 additions & 49 deletions pandas/core/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def str_extract(arr, pat, flags=0, expand=None):
return _str_extract_frame(arr._orig, pat, flags=flags)
else:
result, name = _str_extract_noexpand(arr._data, pat, flags=flags)
return arr._wrap_result(result, name=name)
return arr._wrap_result(result, name=name, expand=expand)


def str_extractall(arr, pat, flags=0):
Expand Down Expand Up @@ -1292,7 +1292,10 @@ def __iter__(self):
i += 1
g = self.get(i)

def _wrap_result(self, result, use_codes=True, name=None):
def _wrap_result(self, result, use_codes=True,
name=None, expand=None):

from pandas.core.index import Index, MultiIndex

# for category, we do the stuff on the categories, so blow it up
# to the full series again
Expand All @@ -1302,48 +1305,42 @@ def _wrap_result(self, result, use_codes=True, name=None):
if use_codes and self._is_categorical:
result = take_1d(result, self._orig.cat.codes)

# leave as it is to keep extract and get_dummies results
# can be merged to _wrap_result_expand in v0.17
from pandas.core.series import Series
from pandas.core.frame import DataFrame
from pandas.core.index import Index

if not hasattr(result, 'ndim'):
if not hasattr(result, 'ndim') or not hasattr(result, 'dtype'):
return result
assert result.ndim < 3

if result.ndim == 1:
# Wait until we are sure result is a Series or Index before
# checking attributes (GH 12180)
name = name or getattr(result, 'name', None) or self._orig.name
if isinstance(self._orig, Index):
# if result is a boolean np.array, return the np.array
# instead of wrapping it into a boolean Index (GH 8875)
if is_bool_dtype(result):
return result
return Index(result, name=name)
return Series(result, index=self._orig.index, name=name)
else:
assert result.ndim < 3
return DataFrame(result, index=self._orig.index)
if expand is None:
# infer from ndim if expand is not specified
expand = False if result.ndim == 1 else True

elif expand is True and not isinstance(self._orig, Index):
# required when expand=True is explicitly specified
# not needed when inferred

def cons_row(x):
if is_list_like(x):
return x
else:
return [x]

result = [cons_row(x) for x in result]

def _wrap_result_expand(self, result, expand=False):
if not isinstance(expand, bool):
raise ValueError("expand must be True or False")

# for category, we do the stuff on the categories, so blow it up
# to the full series again
if self._is_categorical:
result = take_1d(result, self._orig.cat.codes)

from pandas.core.index import Index, MultiIndex
if not hasattr(result, 'ndim'):
return result
if name is None:
name = getattr(result, 'name', None)
if name is None:
# do not use logical `or`: _orig may be a DataFrame
# which has a "name" column
name = self._orig.name

# Wait until we are sure result is a Series or Index before
# checking attributes (GH 12180)
if isinstance(self._orig, Index):
name = getattr(result, 'name', None)
# if result is a boolean np.array, return the np.array
# instead of wrapping it into a boolean Index (GH 8875)
if hasattr(result, 'dtype') and is_bool_dtype(result):
if is_bool_dtype(result):
return result

if expand:
Expand All @@ -1354,18 +1351,10 @@ def _wrap_result_expand(self, result, expand=False):
else:
index = self._orig.index
if expand:

def cons_row(x):
if is_list_like(x):
return x
else:
return [x]

cons = self._orig._constructor_expanddim
data = [cons_row(x) for x in result]
return cons(data, index=index)
return cons(result, index=index)
else:
name = getattr(result, 'name', None)
# Must be a Series
cons = self._orig._constructor
return cons(result, name=name, index=index)

Expand All @@ -1380,12 +1369,12 @@ def cat(self, others=None, sep=None, na_rep=None):
@copy(str_split)
def split(self, pat=None, n=-1, expand=False):
result = str_split(self._data, pat, n=n)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@copy(str_rsplit)
def rsplit(self, pat=None, n=-1, expand=False):
result = str_rsplit(self._data, pat, n=n)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

_shared_docs['str_partition'] = ("""
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
Expand Down Expand Up @@ -1440,7 +1429,7 @@ def rsplit(self, pat=None, n=-1, expand=False):
def partition(self, pat=' ', expand=True):
f = lambda x: x.partition(pat)
result = _na_map(f, self._data)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@Appender(_shared_docs['str_partition'] % {
'side': 'last',
Expand All @@ -1451,7 +1440,7 @@ def partition(self, pat=' ', expand=True):
def rpartition(self, pat=' ', expand=True):
f = lambda x: x.rpartition(pat)
result = _na_map(f, self._data)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@copy(str_get)
def get(self, i):
Expand Down Expand Up @@ -1597,7 +1586,8 @@ def get_dummies(self, sep='|'):
# methods available for making the dummies...
data = self._orig.astype(str) if self._is_categorical else self._data
result = str_get_dummies(data, sep)
return self._wrap_result(result, use_codes=(not self._is_categorical))
return self._wrap_result(result, use_codes=(not self._is_categorical),
expand=True)

@copy(str_translate)
def translate(self, table, deletechars=None):
Expand Down