Skip to content

Commit cdc51d4

Browse files
committed
BUG/API: groupby head and tail act like filter, since they don't aggregate; fixes column selection
1 parent e91a091 commit cdc51d4

File tree

2 files changed

+35
-25
lines changed

2 files changed

+35
-25
lines changed

pandas/core/groupby.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def head(self, n=5):
587587
"""
588588
Returns first n rows of each group.
589589
590-
Essentially equivalent to ``.apply(lambda x: x.head(n))``
590+
Essentially equivalent to ``.apply(lambda x: x.head(n))`` except ignores as_index flag.
591591
592592
Example
593593
-------
@@ -599,17 +599,15 @@ def head(self, n=5):
599599
0 1 2
600600
2 5 6
601601
>>> df.groupby('A').head(1)
602-
A B
603-
A
604-
1 0 1 2
605-
5 2 5 6
602+
A B
603+
0 1 2
604+
2 5 6
606605
607606
"""
607+
obj = self._selected_obj
608608
rng = np.arange(self.grouper._max_groupsize, dtype='int64')
609609
in_head = self._cumcount_array(rng) < n
610-
head = self.obj[in_head]
611-
if self.as_index:
612-
head.index = self._index_with_as_index(in_head)
610+
head = obj[in_head]
613611
return head
614612

615613
def tail(self, n=5):
@@ -628,17 +626,15 @@ def tail(self, n=5):
628626
0 1 2
629627
2 5 6
630628
>>> df.groupby('A').head(1)
631-
A B
632-
A
633-
1 0 1 2
634-
5 2 5 6
629+
A B
630+
0 1 2
631+
2 5 6
635632
636633
"""
634+
obj = self._selected_obj
637635
rng = np.arange(0, -self.grouper._max_groupsize, -1, dtype='int64')
638636
in_tail = self._cumcount_array(rng, ascending=False) > -n
639-
tail = self.obj[in_tail]
640-
if self.as_index:
641-
tail.index = self._index_with_as_index(in_tail)
637+
tail = obj[in_tail]
642638
return tail
643639

644640
def _cumcount_array(self, arr, **kwargs):
@@ -654,6 +650,13 @@ def _cumcount_array(self, arr, **kwargs):
654650
cumcounts[v] = arr[len(v)-1::-1]
655651
return cumcounts
656652

653+
@cache_readonly
654+
def _selected_obj(self):
655+
if self._selection is None or isinstance(self.obj, Series):
656+
return self.obj
657+
else:
658+
return self.obj[self._selection]
659+
657660
def _index_with_as_index(self, b):
658661
"""
659662
Take boolean mask of index to be returned from apply, if as_index=True

pandas/tests/test_groupby.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,12 +1315,10 @@ def test_groupby_as_index_apply(self):
13151315
g_not_as = df.groupby('user_id', as_index=False)
13161316

13171317
res_as = g_as.head(2).index
1318-
exp_as = MultiIndex.from_tuples([(1, 0), (2, 1), (1, 2), (3, 4)])
1319-
assert_index_equal(res_as, exp_as)
1320-
13211318
res_not_as = g_not_as.head(2).index
1322-
exp_not_as = Index([0, 1, 2, 4])
1323-
assert_index_equal(res_not_as, exp_not_as)
1319+
exp = Index([0, 1, 2, 4])
1320+
assert_index_equal(res_as, exp)
1321+
assert_index_equal(res_not_as, exp)
13241322

13251323
res_as_apply = g_as.apply(lambda x: x.head(2)).index
13261324
res_not_as_apply = g_not_as.apply(lambda x: x.head(2)).index
@@ -1355,11 +1353,8 @@ def test_groupby_head_tail(self):
13551353
assert_frame_equal(df, g_not_as.head(7)) # contains all
13561354
assert_frame_equal(df, g_not_as.tail(7))
13571355

1358-
# as_index=True, yuck
1359-
# prepend the A column as an index, in a roundabout way
1360-
df_as = df.copy()
1361-
df_as.index = df.set_index('A', append=True,
1362-
drop=False).index.swaplevel(0, 1)
1356+
# as_index=True, (used to be different)
1357+
df_as = df
13631358

13641359
assert_frame_equal(df_as.loc[[0, 2]], g_as.head(1))
13651360
assert_frame_equal(df_as.loc[[1, 2]], g_as.tail(1))
@@ -1373,6 +1368,18 @@ def test_groupby_head_tail(self):
13731368
assert_frame_equal(df_as, g_as.head(7)) # contains all
13741369
assert_frame_equal(df_as, g_as.tail(7))
13751370

1371+
# test with selection
1372+
assert_frame_equal(g_as[[]].head(1), df_as.loc[[0,2], []])
1373+
assert_frame_equal(g_as[['A']].head(1), df_as.loc[[0,2], ['A']])
1374+
assert_frame_equal(g_as[['B']].head(1), df_as.loc[[0,2], ['B']])
1375+
assert_frame_equal(g_as[['A', 'B']].head(1), df_as.loc[[0,2]])
1376+
1377+
assert_frame_equal(g_not_as[[]].head(1), df_as.loc[[0,2], []])
1378+
assert_frame_equal(g_not_as[['A']].head(1), df_as.loc[[0,2], ['A']])
1379+
assert_frame_equal(g_not_as[['B']].head(1), df_as.loc[[0,2], ['B']])
1380+
assert_frame_equal(g_not_as[['A', 'B']].head(1), df_as.loc[[0,2]])
1381+
1382+
13761383
def test_groupby_multiple_key(self):
13771384
df = tm.makeTimeDataFrame()
13781385
grouped = df.groupby([lambda x: x.year,

0 commit comments

Comments
 (0)