diff --git a/doc/source/release.rst b/doc/source/release.rst index 159deaabb943f..937df6eb7cf3a 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -275,6 +275,8 @@ See :ref:`Internal Refactoring` have internal setitem_with_indexer in core/indexing to use Block.setitem - Fixed bug where thousands operator was not handled correctly for floating point numbers in csv_import (:issue:`4322`) + - The ``as_index=False`` argument to ``groupby`` now works with apply + (:issue:`4648`, :issue:`3417`) pandas 0.12 =========== diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 1f15f1a8ae10d..33df324fa94d7 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -104,6 +104,12 @@ def _last(x): return _last(x) +def _possibly_reset_index(obj, as_index): + if not as_index: + obj.reset_index(drop=True, inplace=True) + return obj + + class GroupBy(PandasObject): """ Class for grouping and aggregating relational data. See aggregate, @@ -525,7 +531,7 @@ def _concat_objects(self, keys, values, not_indexed_same=False): else: result = concat(values, axis=self.axis) - return result + return _possibly_reset_index(result, self.as_index) @Appender(GroupBy.__doc__) @@ -1602,6 +1608,7 @@ def filter(self, func, dropna=True, *args, **kwargs): else: return filtered.reindex(self.obj.index) # Fill with NaNs. + class NDFrameGroupBy(GroupBy): def _iterate_slices(self): @@ -1786,7 +1793,7 @@ def _aggregate_multiple_funcs(self, arg): grouper=self.grouper) results.append(colg.aggregate(arg)) keys.append(col) - except (TypeError, DataError) : + except (TypeError, DataError): pass except SpecificationError: raise @@ -1933,13 +1940,16 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False): except ValueError: #GH1738,, values is list of arrays of unequal lengths # fall through to the outer else caluse - return Series(values, index=key_index) + s = Series(values, index=key_index) + return _possibly_reset_index(s, self.as_index) - return DataFrame(stacked_values, index=index, - columns=columns).convert_objects() + df = DataFrame(stacked_values, index=index, + columns=columns).convert_objects() + return _possibly_reset_index(df, self.as_index) else: - return Series(values, index=key_index) + s = Series(values, index=key_index) + return _possibly_reset_index(s, self.as_index) else: # Handle cases like BinGrouper return self._concat_objects(keys, values, diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index 9e7cdf9df2c6b..6f5127c99d471 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -4,12 +4,13 @@ from datetime import datetime from numpy import nan +from numpy.testing import assert_array_equal -from pandas import bdate_range +from pandas import bdate_range, date_range, period_range from pandas.core.index import Index, MultiIndex from pandas.core.common import rands from pandas.core.api import Categorical, DataFrame -from pandas.core.groupby import GroupByError, SpecificationError, DataError +from pandas.core.groupby import SpecificationError, DataError from pandas.core.series import Series from pandas.util.testing import (assert_panel_equal, assert_frame_equal, assert_series_equal, assert_almost_equal) @@ -21,9 +22,7 @@ from pandas.tools.merge import concat from collections import defaultdict import pandas.core.common as com -import pandas.core.datetools as dt import numpy as np -from numpy.testing import assert_equal import pandas.core.nanops as nanops @@ -336,7 +335,7 @@ def test_agg_datetimes_mixed(self): assert(len(gb1) == len(gb2)) def test_agg_period_index(self): - from pandas import period_range, PeriodIndex + from pandas import PeriodIndex prng = period_range('2012-1-1', freq='M', periods=3) df = DataFrame(np.random.randn(3, 2), index=prng) rs = df.groupby(level=0).sum() @@ -359,7 +358,7 @@ def test_agg_must_agg(self): def test_agg_ser_multi_key(self): ser = self.df.C f = lambda x: x.sum() - results = self.df.C.groupby([self.df.A, self.df.B]).aggregate(f) + results = ser.groupby([self.df.A, self.df.B]).aggregate(f) expected = self.df.groupby(['A', 'B']).sum()['C'] assert_series_equal(results, expected) @@ -1561,6 +1560,60 @@ def f(group): for key, group in grouped: assert_frame_equal(result.ix[key], f(group)) + def test_apply_as_index_is_false_frame(self): + indexes = (None, + list('abcdef'), + date_range('20010101', periods=6), + period_range('20010101', periods=6), + Index(np.random.randn(6), name='a' + tm.rands(4)), + MultiIndex.from_tuples(lzip(range(6), [1,1,1,2,2,2]), + names=['b' + tm.rands(4), + 'c' + tm.rands(4)])) + for index in indexes: + # test with head + df = DataFrame({'item_id': ['b', 'b', 'a', 'c', 'a', 'b'], + 'user_id': [1,2,1,1,3,1], 'time': lrange(6)}, + index=index) + gb = df.groupby('user_id', as_index=False) + assert_array_equal(gb.head(2).index, Index(np.arange(4))) + + # test with replace + with tm.assert_produces_warning(UserWarning): + res = gb.replace({'item_id': {'b': 'c'}}) + assert_array_equal(res.index, Index(np.arange(6))) + + # test with dropna + df.item_id[0] = np.nan + gb = df.groupby('user_id', as_index=False) + res = gb.dropna() + assert_array_equal(res.index, Index(np.arange(5))) + + def test_apply_as_index_is_false_multiple_funcs(self): + indexes = (None, + list('abcdefghi'), + date_range('20010101', periods=9), + period_range('20010101', periods=9), + Index(np.random.randn(9), name='a' + tm.rands(4)), + MultiIndex.from_tuples(lzip(range(9), [1,1,1,2,2,2,3,3,3]), + names=['b' + tm.rands(4), + 'c' + tm.rands(4)])) + for index in indexes: + # GH3417 + df = DataFrame({'a': [1,1,1,2,2,2,3,3,3], 'b': range(1, 10)}, + index=index) + + def f(x): + if x.a[:1] == 2: + mean, std = nan, nan + else: + mean, std = x.b.mean(), x.b.std() + return Series({'mean': mean, 'std': std}) + + gb = df.groupby('a', as_index=False) + + res = gb.apply(f) + assert_array_equal(res.index, Index(np.arange(3))) + def test_mutate_groups(self): # GH3380 @@ -2702,7 +2755,6 @@ def testit(label_list, shape): if __name__ == '__main__': - import nose nose.runmodule( argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure', '-s'], exit=False)