@@ -37,8 +37,22 @@ class _Unstacker(object):
3737
3838 Parameters
3939 ----------
40+ values : ndarray
41+ Values of DataFrame to "Unstack"
42+ index : object
43+ Pandas ``Index``
4044 level : int or str, default last level
4145 Level to "unstack". Accepts a name for the level.
46+ value_columns : Index, optional
47+ Pandas ``Index`` or ``MultiIndex`` object if unstacking a DataFrame
48+ fill_value : scalar, optional
49+ Default value to fill in missing values if subgroups do not have the
50+ same set of labels. By default, missing values will be replaced with
51+ the default fill value for that data type, NaN for float, NaT for
52+ datetimelike, etc. For integer types, by default data will converted to
53+ float and missing values will be set to NaN.
54+ obj : Series, DataFrame
55+ Object that is being unstacked (also used to get subclass)
4256
4357 Examples
4458 --------
@@ -68,8 +82,12 @@ class _Unstacker(object):
6882 unstacked : DataFrame
6983 """
7084
71- def __init__ (self , values , index , level = - 1 , value_columns = None ,
72- fill_value = None ):
85+ def __init__ (self , values = None , index = None , level = - 1 , value_columns = None ,
86+ fill_value = None , obj = None ):
87+
88+ if obj is not None :
89+ values = obj .values
90+ index = obj .index
7391
7492 self .is_categorical = None
7593 self .is_sparse = is_sparse (values )
@@ -85,6 +103,7 @@ def __init__(self, values, index, level=-1, value_columns=None,
85103 self .values = values
86104 self .value_columns = value_columns
87105 self .fill_value = fill_value
106+ self .obj = obj
88107
89108 if value_columns is None and values .shape [1 ] != 1 : # pragma: no cover
90109 raise ValueError ('must pass column labels for multi-column data' )
@@ -173,8 +192,11 @@ def get_result(self):
173192 ordered = ordered )
174193 for i in range (values .shape [- 1 ])]
175194
176- klass = SparseDataFrame if self .is_sparse else DataFrame
177- return klass (values , index = index , columns = columns )
195+ if isinstance (self .obj , Series ):
196+ constructor = self .obj ._constructor_expanddim
197+ else :
198+ constructor = self .obj ._constructor
199+ return constructor (values , index = index , columns = columns )
178200
179201 def get_new_values (self ):
180202 values = self .values
@@ -374,8 +396,9 @@ def pivot(self, index=None, columns=None, values=None):
374396 index = self .index
375397 else :
376398 index = self [index ]
377- indexed = Series (self [values ].values ,
378- index = MultiIndex .from_arrays ([index , self [columns ]]))
399+ indexed = self ._constructor_sliced (
400+ self [values ].values ,
401+ index = MultiIndex .from_arrays ([index , self [columns ]]))
379402 return indexed .unstack (columns )
380403
381404
@@ -460,7 +483,7 @@ def unstack(obj, level, fill_value=None):
460483 else :
461484 return obj .T .stack (dropna = False )
462485 else :
463- unstacker = _Unstacker (obj . values , obj . index , level = level ,
486+ unstacker = _Unstacker (obj = obj , level = level ,
464487 fill_value = fill_value )
465488 return unstacker .get_result ()
466489
@@ -470,10 +493,9 @@ def _unstack_frame(obj, level, fill_value=None):
470493 unstacker = partial (_Unstacker , index = obj .index ,
471494 level = level , fill_value = fill_value )
472495 blocks = obj ._data .unstack (unstacker )
473- klass = type (obj )
474- return klass (blocks )
496+ return obj ._constructor (blocks )
475497 else :
476- unstacker = _Unstacker (obj . values , obj . index , level = level ,
498+ unstacker = _Unstacker (obj = obj , level = level ,
477499 value_columns = obj .columns ,
478500 fill_value = fill_value )
479501 return unstacker .get_result ()
@@ -528,8 +550,7 @@ def factorize(index):
528550 new_values = new_values [mask ]
529551 new_index = new_index [mask ]
530552
531- klass = type (frame )._constructor_sliced
532- return klass (new_values , index = new_index )
553+ return frame ._constructor_sliced (new_values , index = new_index )
533554
534555
535556def stack_multiple (frame , level , dropna = True ):
@@ -675,7 +696,7 @@ def _convert_level_number(level_num, columns):
675696 new_index = MultiIndex (levels = new_levels , labels = new_labels ,
676697 names = new_names , verify_integrity = False )
677698
678- result = DataFrame (new_data , index = new_index , columns = new_columns )
699+ result = frame . _constructor (new_data , index = new_index , columns = new_columns )
679700
680701 # more efficient way to go about this? can do the whole masking biz but
681702 # will only save a small amount of time...
0 commit comments