Skip to content

Commit 69356d8

Browse files
committed
BUG: fix groupby.aggregate resulting dtype coercion, xref pandas-dev#11444
1 parent f5f244a commit 69356d8

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

doc/source/whatsnew/v0.20.0.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ Bug Fixes
628628

629629

630630
- Bug in ``.read_csv()`` with ``parse_dates`` when multiline headers are specified (:issue:`15376`)
631-
- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`)
631+
- Bug in ``groupby.transform()`` and ``groupby.aggregate()`` that would coerce the resultant dtypes back to the original (:issue:`10972`, :issue:`11444`)
632632

633633

634634
- Bug in ``DataFrame.boxplot`` where ``fontsize`` was not applied to the tick labels on both axes (:issue:`15108`)

pandas/core/groupby.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -767,19 +767,23 @@ def _index_with_as_index(self, b):
767767
new.names = gp.names + original.names
768768
return new
769769

770-
def _try_cast(self, result, obj):
770+
def _try_cast(self, result, obj, numeric_only=False):
771771
"""
772772
try to cast the result to our obj original type,
773773
we may have round-tripped through object in the meantime
774774
775+
if numeric_only is True, then only try to cast numerics
776+
and not datetimelikes
777+
775778
"""
776779
if obj.ndim > 1:
777780
dtype = obj.values.dtype
778781
else:
779782
dtype = obj.dtype
780783

781784
if not is_scalar(result):
782-
result = _possibly_downcast_to_dtype(result, dtype)
785+
if numeric_only and is_numeric_dtype(dtype) or not numeric_only:
786+
result = _possibly_downcast_to_dtype(result, dtype)
783787

784788
return result
785789

@@ -830,7 +834,7 @@ def _python_agg_general(self, func, *args, **kwargs):
830834
for name, obj in self._iterate_slices():
831835
try:
832836
result, counts = self.grouper.agg_series(obj, f)
833-
output[name] = self._try_cast(result, obj)
837+
output[name] = self._try_cast(result, obj, numeric_only=True)
834838
except TypeError:
835839
continue
836840

@@ -2908,7 +2912,8 @@ def transform(self, func, *args, **kwargs):
29082912
result = concat(results).sort_index()
29092913

29102914
# we will only try to coerce the result type if
2911-
# we have a numeric dtype
2915+
# we have a numeric dtype, as these are *always* udfs
2916+
# the cython functions take a different path (and casting)
29122917
dtype = self._selected_obj.dtype
29132918
if is_numeric_dtype(dtype):
29142919
result = _possibly_downcast_to_dtype(result, dtype)

pandas/tests/groupby/test_aggregate.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,23 @@ def test_agg_dict_parameter_cast_result_dtypes(self):
154154
assert_series_equal(grouped.time.last(), exp['time'])
155155
assert_series_equal(grouped.time.agg('last'), exp['time'])
156156

157+
# count
158+
exp = pd.Series([0, 1, 1, 2],
159+
index=Index(list('ABCD'), name='class'),
160+
name='time')
161+
assert_series_equal(grouped.time.count(), exp)
162+
assert_series_equal(grouped.time.agg(len), exp)
163+
164+
def test_agg_cast_results_dtypes(self):
165+
# xref #11444
166+
u = [datetime(2015, x + 1, 1) for x in range(12)]
167+
v = list('aaabbbbbbccd')
168+
df = pd.DataFrame({'X': v, 'Y': u})
169+
170+
result = df.groupby('X')['Y'].agg(len)
171+
expected = df.groupby('X')['Y'].count()
172+
assert_series_equal(result, expected)
173+
157174
def test_agg_must_agg(self):
158175
grouped = self.df.groupby('A')['C']
159176
self.assertRaises(Exception, grouped.agg, lambda x: x.describe())

0 commit comments

Comments
 (0)