diff --git a/doc/source/v0.15.0.txt b/doc/source/v0.15.0.txt index 9279d8b0288c4..523939b39c580 100644 --- a/doc/source/v0.15.0.txt +++ b/doc/source/v0.15.0.txt @@ -336,7 +336,8 @@ Bug Fixes - +- Bug in ``GroupBy.filter()`` where fast path vs. slow path made the filter + return a non-scalar value that appeared valid but wasn't (:issue:`7870`). diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index eabe1b43004df..93be135e9ff40 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -2945,48 +2945,34 @@ def filter(self, func, dropna=True, *args, **kwargs): >>> grouped = df.groupby(lambda x: mapping[x]) >>> grouped.filter(lambda x: x['A'].sum() + x['B'].sum() > 0) """ - from pandas.tools.merge import concat indices = [] obj = self._selected_obj gen = self.grouper.get_iterator(obj, axis=self.axis) - fast_path, slow_path = self._define_paths(func, *args, **kwargs) - - path = None for name, group in gen: object.__setattr__(group, 'name', name) - if path is None: - # Try slow path and fast path. - try: - path, res = self._choose_path(fast_path, slow_path, group) - except Exception: # pragma: no cover - res = fast_path(group) - path = fast_path - else: - res = path(group) + res = func(group) - def add_indices(): - indices.append(self._get_index(name)) + try: + res = res.squeeze() + except AttributeError: # allow e.g., scalars and frames to pass + pass # interpret the result of the filter - if isinstance(res, (bool, np.bool_)): - if res: - add_indices() + if (isinstance(res, (bool, np.bool_)) or + np.isscalar(res) and isnull(res)): + if res and notnull(res): + indices.append(self._get_index(name)) else: - if getattr(res, 'ndim', None) == 1: - val = res.ravel()[0] - if val and notnull(val): - add_indices() - else: - - # in theory you could do .all() on the boolean result ? 
- raise TypeError("the filter must return a boolean result") + # non scalars aren't allowed + raise TypeError("filter function returned a %s, " + "but expected a scalar bool" % + type(res).__name__) - filtered = self._apply_filter(indices, dropna) - return filtered + return self._apply_filter(indices, dropna) class DataFrameGroupBy(NDFrameGroupBy): diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index 5adaacbeb9d29..f958d5481ad33 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -3968,6 +3968,32 @@ def test_filter_has_access_to_grouped_cols(self): filt = g.filter(lambda x: x['A'].sum() == 2) assert_frame_equal(filt, df.iloc[[0, 1]]) + def test_filter_enforces_scalarness(self): + df = pd.DataFrame([ + ['best', 'a', 'x'], + ['worst', 'b', 'y'], + ['best', 'c', 'x'], + ['best','d', 'y'], + ['worst','d', 'y'], + ['worst','d', 'y'], + ['best','d', 'z'], + ], columns=['a', 'b', 'c']) + with tm.assertRaisesRegexp(TypeError, 'filter function returned a.*'): + df.groupby('c').filter(lambda g: g['a'] == 'best') + + def test_filter_non_bool_raises(self): + df = pd.DataFrame([ + ['best', 'a', 1], + ['worst', 'b', 1], + ['best', 'c', 1], + ['best','d', 1], + ['worst','d', 1], + ['worst','d', 1], + ['best','d', 1], + ], columns=['a', 'b', 'c']) + with tm.assertRaisesRegexp(TypeError, 'filter function returned a.*'): + df.groupby('a').filter(lambda g: g.c.mean()) + def test_index_label_overlaps_location(self): # checking we don't have any label/location confusion in the # the wake of GH5375