Skip to content

Commit e20199f

Browse files
author
y-p
committed
BUG: add sanity check to groupby agg function, outside main loop
1 parent 4d92994 commit e20199f

File tree

3 files changed

+102
-24
lines changed

3 files changed

+102
-24
lines changed

pandas/core/groupby.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,8 @@ def _groupby_function(name, alias, npfunc, numeric_only=True,
5757
def f(self):
5858
try:
5959
return self._cython_agg_general(alias, numeric_only=numeric_only)
60+
except AssertionError as e:
61+
raise SpecificationError(str(e))
6062
except Exception:
6163
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
6264
if _convert:
@@ -348,7 +350,7 @@ def mean(self):
348350
"""
349351
try:
350352
return self._cython_agg_general('mean')
351-
except DataError:
353+
except GroupByError:
352354
raise
353355
except Exception: # pragma: no cover
354356
f = lambda x: x.mean(axis=self.axis)
@@ -362,7 +364,7 @@ def median(self):
362364
"""
363365
try:
364366
return self._cython_agg_general('median')
365-
except DataError:
367+
except GroupByError:
366368
raise
367369
except Exception: # pragma: no cover
368370
f = lambda x: x.median(axis=self.axis)
@@ -462,7 +464,10 @@ def _cython_agg_general(self, how, numeric_only=True):
462464
if numeric_only and not is_numeric:
463465
continue
464466

465-
result, names = self.grouper.aggregate(obj.values, how)
467+
try:
468+
result, names = self.grouper.aggregate(obj.values, how)
469+
except AssertionError as e:
470+
raise GroupByError(str(e))
466471
output[name] = result
467472

468473
if len(output) == 0:
@@ -1725,9 +1730,10 @@ def _aggregate_multiple_funcs(self, arg):
17251730
grouper=self.grouper)
17261731
results.append(colg.aggregate(arg))
17271732
keys.append(col)
1728-
except (TypeError, DataError):
1733+
except (TypeError, DataError) :
17291734
pass
1730-
1735+
except SpecificationError:
1736+
raise
17311737
result = concat(results, keys=keys, axis=1)
17321738

17331739
return result

pandas/src/generate_code.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -628,6 +628,9 @@ def group_last_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
628628
ndarray[%(dest_type2)s, ndim=2] resx
629629
ndarray[int64_t, ndim=2] nobs
630630
631+
if not len(values) == len(labels):
632+
raise AssertionError("len(index) != len(labels)")
633+
631634
nobs = np.zeros((<object> out).shape, dtype=np.int64)
632635
resx = np.empty_like(out)
633636
@@ -763,6 +766,9 @@ def group_nth_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
763766
ndarray[%(dest_type2)s, ndim=2] resx
764767
ndarray[int64_t, ndim=2] nobs
765768
769+
if not len(values) == len(labels):
770+
raise AssertionError("len(index) != len(labels)")
771+
766772
nobs = np.zeros((<object> out).shape, dtype=np.int64)
767773
resx = np.empty_like(out)
768774
@@ -805,6 +811,9 @@ def group_add_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
805811
%(dest_type2)s val, count
806812
ndarray[%(dest_type2)s, ndim=2] sumx, nobs
807813
814+
if not len(values) == len(labels):
815+
raise AssertionError("len(index) != len(labels)")
816+
808817
nobs = np.zeros_like(out)
809818
sumx = np.zeros_like(out)
810819
@@ -918,6 +927,9 @@ def group_prod_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
918927
%(dest_type2)s val, count
919928
ndarray[%(dest_type2)s, ndim=2] prodx, nobs
920929
930+
if not len(values) == len(labels):
931+
raise AssertionError("len(index) != len(labels)")
932+
921933
nobs = np.zeros_like(out)
922934
prodx = np.ones_like(out)
923935
@@ -1028,6 +1040,9 @@ def group_var_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
10281040
%(dest_type2)s val, ct
10291041
ndarray[%(dest_type2)s, ndim=2] nobs, sumx, sumxx
10301042
1043+
if not len(values) == len(labels):
1044+
raise AssertionError("len(index) != len(labels)")
1045+
10311046
nobs = np.zeros_like(out)
10321047
sumx = np.zeros_like(out)
10331048
sumxx = np.zeros_like(out)
@@ -1223,6 +1238,9 @@ def group_max_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
12231238
%(dest_type2)s val, count
12241239
ndarray[%(dest_type2)s, ndim=2] maxx, nobs
12251240
1241+
if not len(values) == len(labels):
1242+
raise AssertionError("len(index) != len(labels)")
1243+
12261244
nobs = np.zeros_like(out)
12271245
12281246
maxx = np.empty_like(out)
@@ -1345,6 +1363,9 @@ def group_min_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
13451363
%(dest_type2)s val, count
13461364
ndarray[%(dest_type2)s, ndim=2] minx, nobs
13471365
1366+
if not len(values) == len(labels):
1367+
raise AssertionError("len(index) != len(labels)")
1368+
13481369
nobs = np.zeros_like(out)
13491370
13501371
minx = np.empty_like(out)
@@ -1402,6 +1423,9 @@ def group_mean_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
14021423
%(dest_type2)s val, count
14031424
ndarray[%(dest_type2)s, ndim=2] sumx, nobs
14041425
1426+
if not len(values) == len(labels):
1427+
raise AssertionError("len(index) != len(labels)")
1428+
14051429
nobs = np.zeros_like(out)
14061430
sumx = np.zeros_like(out)
14071431

0 commit comments

Comments
 (0)