Skip to content

Commit 6a90819

Browse files
committed
BUG: fix groupby.aggregate resulting dtype coercion, xref pandas-dev#11444
make sure .size includes the name of the grouped
1 parent f5f244a commit 6a90819

File tree

3 files changed

+39
-7
lines changed

3 files changed

+39
-7
lines changed

doc/source/whatsnew/v0.20.0.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ Bug Fixes
628628

629629

630630
- Bug in ``.read_csv()`` with ``parse_dates`` when multiline headers are specified (:issue:`15376`)
631-
- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`)
631+
- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`, :issue:`11444`)
632632

633633

634634
- Bug in ``DataFrame.boxplot`` where ``fontsize`` was not applied to the tick labels on both axes (:issue:`15108`)

pandas/core/groupby.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -767,19 +767,23 @@ def _index_with_as_index(self, b):
767767
new.names = gp.names + original.names
768768
return new
769769

770-
def _try_cast(self, result, obj):
770+
def _try_cast(self, result, obj, numeric_only=False):
771771
"""
772772
try to cast the result to our obj original type,
773773
we may have roundtripped thru object in the mean-time
774774
775+
if numeric_only is True, then only try to cast numerics
776+
and not datetimelikes
777+
775778
"""
776779
if obj.ndim > 1:
777780
dtype = obj.values.dtype
778781
else:
779782
dtype = obj.dtype
780783

781784
if not is_scalar(result):
782-
result = _possibly_downcast_to_dtype(result, dtype)
785+
if numeric_only and is_numeric_dtype(dtype) or not numeric_only:
786+
result = _possibly_downcast_to_dtype(result, dtype)
783787

784788
return result
785789

@@ -830,7 +834,7 @@ def _python_agg_general(self, func, *args, **kwargs):
830834
for name, obj in self._iterate_slices():
831835
try:
832836
result, counts = self.grouper.agg_series(obj, f)
833-
output[name] = self._try_cast(result, obj)
837+
output[name] = self._try_cast(result, obj, numeric_only=True)
834838
except TypeError:
835839
continue
836840

@@ -1117,7 +1121,9 @@ def sem(self, ddof=1):
11171121
@Appender(_doc_template)
11181122
def size(self):
11191123
"""Compute group sizes"""
1120-
return self.grouper.size()
1124+
result = self.grouper.size()
1125+
result.name = getattr(self, 'name', None)
1126+
return result
11211127

11221128
sum = _groupby_function('sum', 'add', np.sum)
11231129
prod = _groupby_function('prod', 'prod', np.prod)
@@ -1689,7 +1695,9 @@ def size(self):
16891695
ids, _, ngroup = self.group_info
16901696
ids = _ensure_platform_int(ids)
16911697
out = np.bincount(ids[ids != -1], minlength=ngroup or None)
1692-
return Series(out, index=self.result_index, dtype='int64')
1698+
return Series(out,
1699+
index=self.result_index,
1700+
dtype='int64')
16931701

16941702
@cache_readonly
16951703
def _max_groupsize(self):
@@ -2908,7 +2916,8 @@ def transform(self, func, *args, **kwargs):
29082916
result = concat(results).sort_index()
29092917

29102918
# we will only try to coerce the result type if
2911-
# we have a numeric dtype
2919+
# we have a numeric dtype, as these are *always* udfs
2920+
# the cython take a different path (and casting)
29122921
dtype = self._selected_obj.dtype
29132922
if is_numeric_dtype(dtype):
29142923
result = _possibly_downcast_to_dtype(result, dtype)

pandas/tests/groupby/test_aggregate.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,29 @@ def test_agg_dict_parameter_cast_result_dtypes(self):
154154
assert_series_equal(grouped.time.last(), exp['time'])
155155
assert_series_equal(grouped.time.agg('last'), exp['time'])
156156

157+
# count
158+
exp = pd.Series([2, 2, 2, 2],
159+
index=Index(list('ABCD'), name='class'),
160+
name='time')
161+
assert_series_equal(grouped.time.agg(len), exp)
162+
assert_series_equal(grouped.time.size(), exp)
163+
164+
exp = pd.Series([0, 1, 1, 2],
165+
index=Index(list('ABCD'), name='class'),
166+
name='time')
167+
assert_series_equal(grouped.time.count(), exp)
168+
169+
def test_agg_cast_results_dtypes(self):
170+
# similar to GH12821
171+
# xref #11444
172+
u = [datetime(2015, x + 1, 1) for x in range(12)]
173+
v = list('aaabbbbbbccd')
174+
df = pd.DataFrame({'X': v, 'Y': u})
175+
176+
result = df.groupby('X')['Y'].agg(len)
177+
expected = df.groupby('X')['Y'].count()
178+
assert_series_equal(result, expected)
179+
157180
def test_agg_must_agg(self):
158181
grouped = self.df.groupby('A')['C']
159182
self.assertRaises(Exception, grouped.agg, lambda x: x.describe())

0 commit comments

Comments
 (0)