diff --git a/RELEASE.rst b/RELEASE.rst index 307986ab81681..e4383d3bfcca8 100644 --- a/RELEASE.rst +++ b/RELEASE.rst @@ -79,6 +79,8 @@ pandas 0.11.1 spurious plots from showing up. - Added Faq section on repr display options, to help users customize their setup. - ``where`` operations that result in block splitting are much faster (GH3733_) + - ``groupby`` will now warn with a ``PerformanceWarning`` if an aggregate function + returns an array or list, instead of raising an error. (GH3788_) **API Changes** @@ -312,6 +314,7 @@ pandas 0.11.1 .. _GH3726: https://github.com/pydata/pandas/issues/3726 .. _GH3795: https://github.com/pydata/pandas/issues/3795 .. _GH3814: https://github.com/pydata/pandas/issues/3814 +.. _GH3788: https://github.com/pydata/pandas/issues/3788 pandas 0.11.0 ============= diff --git a/pandas/core/common.py b/pandas/core/common.py index 69f38bf0c7c61..b4395cc6e3bf1 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -42,6 +42,10 @@ class AmbiguousIndexError(PandasError, KeyError): pass +class PerformanceWarning(Warning): + "Baseclass for warnings about performance issues that affect speed, but not functionality." + pass + _POSSIBLY_CAST_DTYPES = set([ np.dtype(t) for t in ['M8[ns]','m8[ns]','O','int8','uint8','int16','uint16','int32','uint32','int64','uint64'] ]) _NS_DTYPE = np.dtype('M8[ns]') _TD_DTYPE = np.dtype('m8[ns]') diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 0be5d438e5e7c..777dfaefcb9ee 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -1,8 +1,10 @@ from itertools import izip import types +import warnings import numpy as np from pandas.core.categorical import Categorical +from pandas.core.common import PerformanceWarning from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame from pandas.core.index import Index, MultiIndex, _ensure_index @@ -18,7 +20,7 @@ import pandas.lib as lib import pandas.algos as _algos import pandas.hashtable as _hash - +_non_agg_warning = "Function does not produce aggregated values. Will not be able to optimize and may produce unexpected results." _agg_doc = """Aggregate using input function or dict of {column -> function} Parameters @@ -919,7 +921,7 @@ def _aggregate_series_pure_python(self, obj, func): res = func(group) if result is None: if isinstance(res, np.ndarray) or isinstance(res, list): - raise ValueError('Function does not reduce') + warnings.warn(_non_agg_warning, PerformanceWarning, stacklevel=2) result = np.empty(ngroups, dtype='O') counts[label] = group.shape[0] @@ -1508,7 +1510,7 @@ def _aggregate_named(self, func, *args, **kwargs): group.name = name output = func(group, *args, **kwargs) if isinstance(output, np.ndarray): - raise Exception('Must produce aggregated value') + warnings.warn(_non_agg_warning, PerformanceWarning, stacklevel=2) result[name] = self._try_cast(output, group) return result diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index b1b7b80e5fd23..24d7921e04c25 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -19,7 +19,7 @@ from pandas.sparse.api import SparseSeries, SparseDataFrame, SparsePanel from pandas.sparse.array import BlockIndex, IntIndex from pandas.tseries.api import PeriodIndex, DatetimeIndex -from pandas.core.common import adjoin, isnull, is_list_like +from pandas.core.common import adjoin, isnull, is_list_like, PerformanceWarning from pandas.core.algorithms import match, unique, factorize from pandas.core.categorical import Categorical from pandas.core.common import _asarray_tuplesafe, _try_sort @@ -64,7 +64,7 @@ class AttributeConflictWarning(Warning): pass the [%s] attribute of the existing index is [%s] which conflicts with the new [%s], resetting the attribute to None """ -class PerformanceWarning(Warning): pass +# for PerformanceWarning performance_doc = """ your performance may suffer as PyTables will pickle object types that it cannot map directly to c-types [inferred_type->%s,key->%s] [items->%s] diff --git a/pandas/io/tests/test_pytables.py b/pandas/io/tests/test_pytables.py index 8b3d4a475d952..ff8207b11f4e3 100644 --- a/pandas/io/tests/test_pytables.py +++ b/pandas/io/tests/test_pytables.py @@ -10,9 +10,11 @@ import pandas from pandas import (Series, DataFrame, Panel, MultiIndex, bdate_range, date_range, Index) +from pandas.core.common import PerformanceWarning from pandas.io.pytables import (HDFStore, get_store, Term, read_hdf, - IncompatibilityWarning, PerformanceWarning, + IncompatibilityWarning, AttributeConflictWarning) +from pandas.core.common import PerformanceWarning import pandas.util.testing as tm from pandas.tests.test_series import assert_series_equal from pandas.tests.test_frame import assert_frame_equal diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index f3a608b82e756..99ada1a5ee6b2 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -1,12 +1,13 @@ import nose import unittest +import warnings from datetime import datetime from numpy import nan from pandas import bdate_range from pandas.core.index import Index, MultiIndex -from pandas.core.common import rands +from pandas.core.common import rands, PerformanceWarning from pandas.core.api import Categorical, DataFrame from pandas.core.groupby import GroupByError, SpecificationError, DataError from pandas.core.series import Series @@ -131,8 +132,11 @@ def checkit(dtype): self.assertEqual(agged[1], 21) # corner cases - self.assertRaises(Exception, grouped.aggregate, lambda x: x * 2) - + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + grouped.aggregate(lambda x: x * 2) + self.assertEqual(len(w), 1) + assert 'aggregate' in str(w[-1].message), "Wrong message: %r" % str(w[-1].message) for dtype in ['int64','int32','float64','float32']: checkit(dtype) @@ -334,8 +338,29 @@ def test_agg_period_index(self): def test_agg_must_agg(self): grouped = self.df.groupby('A')['C'] - self.assertRaises(Exception, grouped.agg, lambda x: x.describe()) - self.assertRaises(Exception, grouped.agg, lambda x: x.index[:2]) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + grouped.agg(lambda x: x.describe()) + self.assertEqual(len(w), 1) + assert 'aggregate' in str(w[-1].message), "Wrong message: %r" % str(w[-1].message) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + grouped.agg(lambda x: x.index[:2]) + self.assertEqual(len(w), 1) + assert 'aggregate' in str(w[-1].message), "Wrong message: %r" % str(w[-1].message) + + # motivating example for #3788 + df = DataFrame([[1, np.array([10, 20, 30])], + [1, np.array([40, 50, 60])], + [2, np.array([20, 30, 40])]], + columns=['category', 'arraydata']) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + df.groupby('category').agg(sum) + self.assertEqual(len(w), 1) + assert 'aggregate' in str(w[-1].message), "Wrong message: %r" % str(w[-1].message) def test_agg_ser_multi_key(self): ser = self.df.C