Reviewer notes. Everything before the first "diff --git" line is preamble
and is not part of any hunk.

Summary of what the hunks below show:
  * doc/source/whatsnew/v0.20.0.txt: adds a Bug Fixes entry for issue 10972.
  * pandas/core/groupby.py (transform hunk): stops writing each group's
    result into a copy of the selected object's values (with a bare
    try/except around the common-type astype).  Instead builds one
    klass(res, indexer) object per group, concatenates them and sorts by
    index.  The downcast to the original dtype is now attempted only when
    is_numeric_dtype(dtype) is true; name and index are assigned last.
  * pandas/tests/groupby/test_filters.py: adds one blank line only.
  * pandas/tests/groupby/test_transform.py: imports date_range and adds
    test_transform_datetime_to_timedelta and
    test_transform_datetime_to_numeric.

NOTE(review): concat(results) is called unconditionally; confirm the
  behaviour when the groupby yields no groups (results == []).
NOTE(review): is_numeric_dtype is used in the groupby.py hunk but no import
  is added by this patch; presumably it is already imported there - confirm.
NOTE(review): the positions of the blank context lines in the whatsnew hunk
  were reconstructed to satisfy its "@@ -626,6 +626,7 @@" counts; confirm
  against the upstream file before applying.

diff --git a/doc/source/whatsnew/v0.20.0.txt b/doc/source/whatsnew/v0.20.0.txt
index be487e165c602..9298d4b058d98 100644
--- a/doc/source/whatsnew/v0.20.0.txt
+++ b/doc/source/whatsnew/v0.20.0.txt
@@ -626,6 +626,7 @@ Bug Fixes
 
 - Bug in ``.read_csv()`` with ``parse_dates`` when multiline headers are specified (:issue:`15376`)
 
+- Bug in ``groupby.transform()`` that would coerce the resultant dtypes back to the original (:issue:`10972`)
 
 - Bug in ``DataFrame.boxplot`` where ``fontsize`` was not applied to the tick labels on both axes (:issue:`15108`)
 
diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py
index 831ca3886773e..2c61a73d6814e 100644
--- a/pandas/core/groupby.py
+++ b/pandas/core/groupby.py
@@ -2890,32 +2890,32 @@ def transform(self, func, *args, **kwargs):
                     lambda: getattr(self, func)(*args, **kwargs))
 
         # reg transform
-        dtype = self._selected_obj.dtype
-        result = self._selected_obj.values.copy()
-
+        klass = self._selected_obj.__class__
+        results = []
         wrapper = lambda x: func(x, *args, **kwargs)
-        for i, (name, group) in enumerate(self):
+        for name, group in self:
             object.__setattr__(group, 'name', name)
             res = wrapper(group)
 
             if hasattr(res, 'values'):
                 res = res.values
 
-            # may need to astype
-            try:
-                common_type = np.common_type(np.array(res), result)
-                if common_type != result.dtype:
-                    result = result.astype(common_type)
-            except:
-                pass
-
             indexer = self._get_index(name)
-            result[indexer] = res
+            s = klass(res, indexer)
+            results.append(s)
 
-        result = _possibly_downcast_to_dtype(result, dtype)
-        return self._selected_obj.__class__(result,
-                                            index=self._selected_obj.index,
-                                            name=self._selected_obj.name)
+        from pandas.tools.concat import concat
+        result = concat(results).sort_index()
+
+        # we will only try to coerce the result type if
+        # we have a numeric dtype
+        dtype = self._selected_obj.dtype
+        if is_numeric_dtype(dtype):
+            result = _possibly_downcast_to_dtype(result, dtype)
+
+        result.name = self._selected_obj.name
+        result.index = self._selected_obj.index
+        return result
 
     def _transform_fast(self, func):
         """
diff --git a/pandas/tests/groupby/test_filters.py b/pandas/tests/groupby/test_filters.py
index 46ddb5a5318fb..de6757786a363 100644
--- a/pandas/tests/groupby/test_filters.py
+++ b/pandas/tests/groupby/test_filters.py
@@ -216,6 +216,7 @@ def test_filter_against_workaround(self):
         grouper = s.apply(lambda x: np.round(x, -1))
         grouped = s.groupby(grouper)
         f = lambda x: x.mean() > 10
+
         old_way = s[grouped.transform(f).astype('bool')]
         new_way = grouped.filter(f)
         assert_series_equal(new_way.sort_values(), old_way.sort_values())
diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py
index cf5e9eb26ff13..51920ec642705 100644
--- a/pandas/tests/groupby/test_transform.py
+++ b/pandas/tests/groupby/test_transform.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pandas as pd
 from pandas.util import testing as tm
-from pandas import Series, DataFrame, Timestamp, MultiIndex, concat
+from pandas import Series, DataFrame, Timestamp, MultiIndex, concat, date_range
 from pandas.types.common import _ensure_platform_int
 from .common import MixIn, assert_fp_equal
 
@@ -190,6 +190,43 @@ def test_transform_bug(self):
         expected = Series(np.arange(5, 0, step=-1), name='B')
         assert_series_equal(result, expected)
 
+    def test_transform_datetime_to_timedelta(self):
+        # GH 15429
+        # transforming a datetime to timedelta
+        df = DataFrame(dict(A=Timestamp('20130101'), B=np.arange(5)))
+        expected = pd.Series([
+            Timestamp('20130101') - Timestamp('20130101')] * 5, name='A')
+
+        # this does date math without changing result type in transform
+        base_time = df['A'][0]
+        result = df.groupby('A')['A'].transform(
+            lambda x: x.max() - x.min() + base_time) - base_time
+        assert_series_equal(result, expected)
+
+        # this does date math and causes the transform to return timedelta
+        result = df.groupby('A')['A'].transform(lambda x: x.max() - x.min())
+        assert_series_equal(result, expected)
+
+    def test_transform_datetime_to_numeric(self):
+        # GH 10972
+        # convert dt to float
+        df = DataFrame({
+            'a': 1, 'b': date_range('2015-01-01', periods=2, freq='D')})
+        result = df.groupby('a').b.transform(
+            lambda x: x.dt.dayofweek - x.dt.dayofweek.mean())
+
+        expected = Series([-0.5, 0.5], name='b')
+        assert_series_equal(result, expected)
+
+        # convert dt to int
+        df = DataFrame({
+            'a': 1, 'b': date_range('2015-01-01', periods=2, freq='D')})
+        result = df.groupby('a').b.transform(
+            lambda x: x.dt.dayofweek - x.dt.dayofweek.min())
+
+        expected = Series([0, 1], name='b')
+        assert_series_equal(result, expected)
+
     def test_transform_multiple(self):
         grouped = self.ts.groupby([lambda x: x.year, lambda x: x.month])
 