Skip to content

Commit 53c209d

Browse files
committed
BUG: Stack/unstack do not return subclassed objects (GH15563)
1 parent cdebcf3 commit 53c209d

File tree

5 files changed

+227
-20
lines changed

5 files changed

+227
-20
lines changed

doc/source/whatsnew/v0.23.0.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ Reshaping
345345
- Fixed construction of a :class:`Series` from a ``dict`` containing ``NaN`` as key (:issue:`18480`)
346346
- Bug in :func:`Series.rank` where ``Series`` containing ``NaT`` modifies the ``Series`` inplace (:issue:`18521`)
347347
- Bug in :func:`Dataframe.pivot_table` which fails when the ``aggfunc`` arg is of type string. The behavior is now consistent with other methods like ``agg`` and ``apply`` (:issue:`18713`)
348-
348+
- Bug in :func:`DataFrame.stack`, `DataFrame.unstack`, `Series.unstack` which were not returning subclasses (:issue:`15563`)
349+
-
349350

350351
Numeric
351352
^^^^^^^

pandas/core/reshape/melt.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def melt(frame, id_vars=None, value_vars=None, var_name=None,
8080
mdata[col] = np.asanyarray(frame.columns
8181
._get_level_values(i)).repeat(N)
8282

83-
from pandas import DataFrame
84-
return DataFrame(mdata, columns=mcolumns)
83+
return frame._constructor(mdata, columns=mcolumns)
8584

8685

8786
def lreshape(data, groups, dropna=True, label=None):
@@ -152,8 +151,7 @@ def lreshape(data, groups, dropna=True, label=None):
152151
if not mask.all():
153152
mdata = {k: v[mask] for k, v in compat.iteritems(mdata)}
154153

155-
from pandas import DataFrame
156-
return DataFrame(mdata, columns=id_cols + pivot_cols)
154+
return data._constructor(mdata, columns=id_cols + pivot_cols)
157155

158156

159157
def wide_to_long(df, stubnames, i, j, sep="", suffix=r'\d+'):

pandas/core/reshape/reshape.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,23 @@ class _Unstacker(object):
3737
3838
Parameters
3939
----------
40+
values : ndarray
41+
Values of DataFrame to "Unstack"
42+
index : object
43+
Pandas ``Index``
4044
level : int or str, default last level
4145
Level to "unstack". Accepts a name for the level.
46+
value_columns : Index, optional
47+
Pandas ``Index`` or ``MultiIndex`` object if unstacking a DataFrame
48+
fill_value : scalar, optional
49+
Default value to fill in missing values if subgroups do not have the
50+
same set of labels. By default, missing values will be replaced with
51+
the default fill value for that data type, NaN for float, NaT for
52+
datetimelike, etc. For integer types, by default data will converted to
53+
float and missing values will be set to NaN.
54+
constructor : object, default DataFrame
55+
``Series``, ``DataFrame``, or subclass used to create unstacked
56+
response
4257
4358
Examples
4459
--------
@@ -69,7 +84,7 @@ class _Unstacker(object):
6984
"""
7085

7186
def __init__(self, values, index, level=-1, value_columns=None,
72-
fill_value=None):
87+
fill_value=None, constructor=None):
7388

7489
self.is_categorical = None
7590
self.is_sparse = is_sparse(values)
@@ -85,6 +100,13 @@ def __init__(self, values, index, level=-1, value_columns=None,
85100
self.values = values
86101
self.value_columns = value_columns
87102
self.fill_value = fill_value
103+
if constructor is None:
104+
if self.is_sparse:
105+
self.constructor = SparseDataFrame
106+
else:
107+
self.constructor = DataFrame
108+
else:
109+
self.constructor = constructor
88110

89111
if value_columns is None and values.shape[1] != 1: # pragma: no cover
90112
raise ValueError('must pass column labels for multi-column data')
@@ -179,8 +201,7 @@ def get_result(self):
179201
ordered=ordered)
180202
for i in range(values.shape[-1])]
181203

182-
klass = SparseDataFrame if self.is_sparse else DataFrame
183-
return klass(values, index=index, columns=columns)
204+
return self.constructor(values, index=index, columns=columns)
184205

185206
def get_new_values(self):
186207
values = self.values
@@ -380,8 +401,9 @@ def pivot(self, index=None, columns=None, values=None):
380401
index = self.index
381402
else:
382403
index = self[index]
383-
indexed = Series(self[values].values,
384-
index=MultiIndex.from_arrays([index, self[columns]]))
404+
indexed = self._constructor_sliced(
405+
self[values].values,
406+
index=MultiIndex.from_arrays([index, self[columns]]))
385407
return indexed.unstack(columns)
386408

387409

@@ -467,7 +489,8 @@ def unstack(obj, level, fill_value=None):
467489
return obj.T.stack(dropna=False)
468490
else:
469491
unstacker = _Unstacker(obj.values, obj.index, level=level,
470-
fill_value=fill_value)
492+
fill_value=fill_value,
493+
constructor=obj._constructor_expanddim)
471494
return unstacker.get_result()
472495

473496

@@ -476,12 +499,13 @@ def _unstack_frame(obj, level, fill_value=None):
476499
unstacker = partial(_Unstacker, index=obj.index,
477500
level=level, fill_value=fill_value)
478501
blocks = obj._data.unstack(unstacker)
479-
klass = type(obj)
502+
klass = obj._constructor
480503
return klass(blocks)
481504
else:
482505
unstacker = _Unstacker(obj.values, obj.index, level=level,
483506
value_columns=obj.columns,
484-
fill_value=fill_value)
507+
fill_value=fill_value,
508+
constructor=obj._constructor)
485509
return unstacker.get_result()
486510

487511

@@ -539,7 +563,7 @@ def factorize(index):
539563
new_values = new_values[mask]
540564
new_index = new_index[mask]
541565

542-
klass = type(frame)._constructor_sliced
566+
klass = frame._constructor_sliced
543567
return klass(new_values, index=new_index)
544568

545569

@@ -686,7 +710,7 @@ def _convert_level_number(level_num, columns):
686710
new_index = MultiIndex(levels=new_levels, labels=new_labels,
687711
names=new_names, verify_integrity=False)
688712

689-
result = DataFrame(new_data, index=new_index, columns=new_columns)
713+
result = frame._constructor(new_data, index=new_index, columns=new_columns)
690714

691715
# more efficient way to go about this? can do the whole masking biz but
692716
# will only save a small amount of time...

pandas/tests/frame/test_subclass.py

Lines changed: 178 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from warnings import catch_warnings
66
import numpy as np
77

8-
from pandas import DataFrame, Series, MultiIndex, Panel
8+
from pandas import DataFrame, Series, MultiIndex, Panel, Index
99
import pandas as pd
1010
import pandas.util.testing as tm
1111

@@ -247,3 +247,180 @@ def test_subclass_sparse_transpose(self):
247247
[2, 5],
248248
[3, 6]])
249249
tm.assert_sp_frame_equal(ossdf.T, essdf)
250+
251+
def test_subclass_stack(self):
252+
# GH 15564
253+
df = tm.SubclassedDataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
254+
index=['a', 'b', 'c'],
255+
columns=['X', 'Y', 'Z'])
256+
257+
res = df.stack()
258+
exp = tm.SubclassedSeries(
259+
[1, 2, 3, 4, 5, 6, 7, 8, 9],
260+
index=[list('aaabbbccc'), list('XYZXYZXYZ')])
261+
262+
tm.assert_series_equal(res, exp)
263+
264+
def test_subclass_stack_multi(self):
265+
# GH 15564
266+
df = tm.SubclassedDataFrame([
267+
[10, 11, 12, 13],
268+
[20, 21, 22, 23],
269+
[30, 31, 32, 33],
270+
[40, 41, 42, 43]],
271+
index=MultiIndex.from_tuples(
272+
list(zip(list('AABB'), list('cdcd'))),
273+
names=['aaa', 'ccc']),
274+
columns=MultiIndex.from_tuples(
275+
list(zip(list('WWXX'), list('yzyz'))),
276+
names=['www', 'yyy']))
277+
278+
exp = tm.SubclassedDataFrame([
279+
[10, 12],
280+
[11, 13],
281+
[20, 22],
282+
[21, 23],
283+
[30, 32],
284+
[31, 33],
285+
[40, 42],
286+
[41, 43]],
287+
index=MultiIndex.from_tuples(list(zip(
288+
list('AAAABBBB'), list('ccddccdd'), list('yzyzyzyz'))),
289+
names=['aaa', 'ccc', 'yyy']),
290+
columns=Index(['W', 'X'], name='www'))
291+
292+
res = df.stack()
293+
tm.assert_frame_equal(res, exp)
294+
295+
res = df.stack('yyy')
296+
tm.assert_frame_equal(res, exp)
297+
298+
exp = tm.SubclassedDataFrame([
299+
[10, 11],
300+
[12, 13],
301+
[20, 21],
302+
[22, 23],
303+
[30, 31],
304+
[32, 33],
305+
[40, 41],
306+
[42, 43]],
307+
index=MultiIndex.from_tuples(list(zip(
308+
list('AAAABBBB'), list('ccddccdd'), list('WXWXWXWX'))),
309+
names=['aaa', 'ccc', 'www']),
310+
columns=Index(['y', 'z'], name='yyy'))
311+
312+
res = df.stack('www')
313+
tm.assert_frame_equal(res, exp)
314+
315+
def test_subclass_unstack(self):
316+
# GH 15564
317+
df = tm.SubclassedDataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
318+
index=['a', 'b', 'c'],
319+
columns=['X', 'Y', 'Z'])
320+
321+
res = df.unstack()
322+
exp = tm.SubclassedSeries(
323+
[1, 4, 7, 2, 5, 8, 3, 6, 9],
324+
index=[list('XXXYYYZZZ'), list('abcabcabc')])
325+
326+
tm.assert_series_equal(res, exp)
327+
328+
def test_subclass_unstack_multi(self):
329+
# GH 15564
330+
df = tm.SubclassedDataFrame([
331+
[10, 11, 12, 13],
332+
[20, 21, 22, 23],
333+
[30, 31, 32, 33],
334+
[40, 41, 42, 43]],
335+
index=MultiIndex.from_tuples(
336+
list(zip(list('AABB'), list('cdcd'))),
337+
names=['aaa', 'ccc']),
338+
columns=MultiIndex.from_tuples(
339+
list(zip(list('WWXX'), list('yzyz'))),
340+
names=['www', 'yyy']))
341+
342+
exp = tm.SubclassedDataFrame([
343+
[10, 20, 11, 21, 12, 22, 13, 23],
344+
[30, 40, 31, 41, 32, 42, 33, 43]],
345+
index=Index(['A', 'B'], name='aaa'),
346+
columns=MultiIndex.from_tuples(list(zip(
347+
list('WWWWXXXX'), list('yyzzyyzz'), list('cdcdcdcd'))),
348+
names=['www', 'yyy', 'ccc']))
349+
350+
res = df.unstack()
351+
tm.assert_frame_equal(res, exp)
352+
353+
res = df.unstack('ccc')
354+
tm.assert_frame_equal(res, exp)
355+
356+
exp = tm.SubclassedDataFrame([
357+
[10, 30, 11, 31, 12, 32, 13, 33],
358+
[20, 40, 21, 41, 22, 42, 23, 43]],
359+
index=Index(['c', 'd'], name='ccc'),
360+
columns=MultiIndex.from_tuples(list(zip(
361+
list('WWWWXXXX'), list('yyzzyyzz'), list('ABABABAB'))),
362+
names=['www', 'yyy', 'aaa']))
363+
364+
res = df.unstack('aaa')
365+
tm.assert_frame_equal(res, exp)
366+
367+
def test_subclass_pivot(self):
368+
# GH 15564
369+
df = tm.SubclassedDataFrame({
370+
'index': ['A', 'B', 'C', 'C', 'B', 'A'],
371+
'columns': ['One', 'One', 'One', 'Two', 'Two', 'Two'],
372+
'values': [1., 2., 3., 3., 2., 1.]})
373+
374+
pivoted = df.pivot(
375+
index='index', columns='columns', values='values')
376+
377+
expected = tm.SubclassedDataFrame({
378+
'One': {'A': 1., 'B': 2., 'C': 3.},
379+
'Two': {'A': 1., 'B': 2., 'C': 3.}})
380+
381+
expected.index.name, expected.columns.name = 'index', 'columns'
382+
383+
tm.assert_frame_equal(pivoted, expected)
384+
385+
def test_subclassed_melt(self):
386+
# GH 15564
387+
cheese = tm.SubclassedDataFrame({
388+
'first': ['John', 'Mary'],
389+
'last': ['Doe', 'Bo'],
390+
'height': [5.5, 6.0],
391+
'weight': [130, 150]})
392+
393+
melted = pd.melt(cheese, id_vars=['first', 'last'])
394+
395+
expected = tm.SubclassedDataFrame([
396+
['John', 'Doe', 'height', 5.5],
397+
['Mary', 'Bo', 'height', 6.0],
398+
['John', 'Doe', 'weight', 130],
399+
['Mary', 'Bo', 'weight', 150]],
400+
columns=['first', 'last', 'variable', 'value'])
401+
402+
tm.assert_frame_equal(melted, expected)
403+
404+
def test_subclassed_wide_to_long(self):
405+
# GH 9762
406+
407+
np.random.seed(123)
408+
x = np.random.randn(3)
409+
df = tm.SubclassedDataFrame({
410+
"A1970": {0: "a", 1: "b", 2: "c"},
411+
"A1980": {0: "d", 1: "e", 2: "f"},
412+
"B1970": {0: 2.5, 1: 1.2, 2: .7},
413+
"B1980": {0: 3.2, 1: 1.3, 2: .1},
414+
"X": dict(zip(range(3), x))})
415+
416+
df["id"] = df.index
417+
exp_data = {"X": x.tolist() + x.tolist(),
418+
"A": ['a', 'b', 'c', 'd', 'e', 'f'],
419+
"B": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1],
420+
"year": [1970, 1970, 1970, 1980, 1980, 1980],
421+
"id": [0, 1, 2, 0, 1, 2]}
422+
expected = tm.SubclassedDataFrame(exp_data)
423+
expected = expected.set_index(['id', 'year'])[["X", "A", "B"]]
424+
long_frame = pd.wide_to_long(df, ["A", "B"], i="id", j="year")
425+
426+
tm.assert_frame_equal(long_frame, expected)

pandas/tests/series/test_subclass.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,31 @@ def test_indexing_sliced(self):
1313
res = s.loc[['a', 'b']]
1414
exp = tm.SubclassedSeries([1, 2], index=list('ab'))
1515
tm.assert_series_equal(res, exp)
16-
assert isinstance(res, tm.SubclassedSeries)
1716

1817
res = s.iloc[[2, 3]]
1918
exp = tm.SubclassedSeries([3, 4], index=list('cd'))
2019
tm.assert_series_equal(res, exp)
21-
assert isinstance(res, tm.SubclassedSeries)
2220

2321
res = s.loc[['a', 'b']]
2422
exp = tm.SubclassedSeries([1, 2], index=list('ab'))
2523
tm.assert_series_equal(res, exp)
26-
assert isinstance(res, tm.SubclassedSeries)
2724

2825
def test_to_frame(self):
2926
s = tm.SubclassedSeries([1, 2, 3, 4], index=list('abcd'), name='xxx')
3027
res = s.to_frame()
3128
exp = tm.SubclassedDataFrame({'xxx': [1, 2, 3, 4]}, index=list('abcd'))
3229
tm.assert_frame_equal(res, exp)
33-
assert isinstance(res, tm.SubclassedDataFrame)
30+
31+
def test_subclass_unstack(self):
32+
# GH 15564
33+
s = tm.SubclassedSeries(
34+
[1, 2, 3, 4], index=[list('aabb'), list('xyxy')])
35+
36+
res = s.unstack()
37+
exp = tm.SubclassedDataFrame(
38+
{'x': [1, 3], 'y': [2, 4]}, index=['a', 'b'])
39+
40+
tm.assert_frame_equal(res, exp)
3441

3542

3643
class TestSparseSeriesSubclassing(object):

0 commit comments

Comments
 (0)