@@ -37,8 +37,23 @@ class _Unstacker(object):
3737
3838 Parameters
3939 ----------
40+ values : ndarray
41+ Values of DataFrame to "Unstack"
42+ index : object
43+ Pandas ``Index``
4044 level : int or str, default last level
4145 Level to "unstack". Accepts a name for the level.
46+ value_columns : Index, optional
47+ Pandas ``Index`` or ``MultiIndex`` object if unstacking a DataFrame
48+ fill_value : scalar, optional
49+ Default value to fill in missing values if subgroups do not have the
50+ same set of labels. By default, missing values will be replaced with
51+ the default fill value for that data type, NaN for float, NaT for
52+ datetimelike, etc. For integer types, by default data will converted to
53+ float and missing values will be set to NaN.
54+ constructor : object
55+ Pandas ``DataFrame`` or subclass used to create unstacked
56+ response. If None, DataFrame or SparseDataFrame will be used.
4257
4358 Examples
4459 --------
@@ -69,7 +84,7 @@ class _Unstacker(object):
6984 """
7085
7186 def __init__ (self , values , index , level = - 1 , value_columns = None ,
72- fill_value = None ):
87+ fill_value = None , constructor = None ):
7388
7489 self .is_categorical = None
7590 self .is_sparse = is_sparse (values )
@@ -86,6 +101,14 @@ def __init__(self, values, index, level=-1, value_columns=None,
86101 self .value_columns = value_columns
87102 self .fill_value = fill_value
88103
104+ if constructor is None :
105+ if self .is_sparse :
106+ self .constructor = SparseDataFrame
107+ else :
108+ self .constructor = DataFrame
109+ else :
110+ self .constructor = constructor
111+
89112 if value_columns is None and values .shape [1 ] != 1 : # pragma: no cover
90113 raise ValueError ('must pass column labels for multi-column data' )
91114
@@ -173,8 +196,7 @@ def get_result(self):
173196 ordered = ordered )
174197 for i in range (values .shape [- 1 ])]
175198
176- klass = SparseDataFrame if self .is_sparse else DataFrame
177- return klass (values , index = index , columns = columns )
199+ return self .constructor (values , index = index , columns = columns )
178200
179201 def get_new_values (self ):
180202 values = self .values
@@ -374,8 +396,9 @@ def pivot(self, index=None, columns=None, values=None):
374396 index = self .index
375397 else :
376398 index = self [index ]
377- indexed = Series (self [values ].values ,
378- index = MultiIndex .from_arrays ([index , self [columns ]]))
399+ indexed = self ._constructor_sliced (
400+ self [values ].values ,
401+ index = MultiIndex .from_arrays ([index , self [columns ]]))
379402 return indexed .unstack (columns )
380403
381404
@@ -461,7 +484,8 @@ def unstack(obj, level, fill_value=None):
461484 return obj .T .stack (dropna = False )
462485 else :
463486 unstacker = _Unstacker (obj .values , obj .index , level = level ,
464- fill_value = fill_value )
487+ fill_value = fill_value ,
488+ constructor = obj ._constructor_expanddim )
465489 return unstacker .get_result ()
466490
467491
@@ -470,12 +494,12 @@ def _unstack_frame(obj, level, fill_value=None):
470494 unstacker = partial (_Unstacker , index = obj .index ,
471495 level = level , fill_value = fill_value )
472496 blocks = obj ._data .unstack (unstacker )
473- klass = type (obj )
474- return klass (blocks )
497+ return obj ._constructor (blocks )
475498 else :
476499 unstacker = _Unstacker (obj .values , obj .index , level = level ,
477500 value_columns = obj .columns ,
478- fill_value = fill_value )
501+ fill_value = fill_value ,
502+ constructor = obj ._constructor )
479503 return unstacker .get_result ()
480504
481505
@@ -528,8 +552,7 @@ def factorize(index):
528552 new_values = new_values [mask ]
529553 new_index = new_index [mask ]
530554
531- klass = type (frame )._constructor_sliced
532- return klass (new_values , index = new_index )
555+ return frame ._constructor_sliced (new_values , index = new_index )
533556
534557
535558def stack_multiple (frame , level , dropna = True ):
@@ -676,7 +699,7 @@ def _convert_level_number(level_num, columns):
676699 new_index = MultiIndex (levels = new_levels , labels = new_labels ,
677700 names = new_names , verify_integrity = False )
678701
679- result = DataFrame (new_data , index = new_index , columns = new_columns )
702+ result = frame . _constructor (new_data , index = new_index , columns = new_columns )
680703
681704 # more efficient way to go about this? can do the whole masking biz but
682705 # will only save a small amount of time...
0 commit comments