diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 1350848741ad1..8afb0e116fea8 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -603,6 +603,23 @@ def _cython_operation( kind, values, how, axis, min_count, **kwargs ) + elif values.ndim == 1: + # expand to 2d, dispatch, then squeeze if appropriate + values2d = values[None, :] + res = self._cython_operation( + kind=kind, + values=values2d, + how=how, + axis=1, + min_count=min_count, + **kwargs, + ) + if res.shape[0] == 1: + return res[0] + + # otherwise we have OHLC + return res.T + is_datetimelike = needs_i8_conversion(dtype) if is_datetimelike: @@ -623,22 +640,20 @@ def _cython_operation( values = values.astype(object) arity = self._cython_arity.get(how, 1) + ngroups = self.ngroups - vdim = values.ndim - swapped = False - if vdim == 1: - values = values[:, None] - out_shape = (self.ngroups, arity) + assert axis == 1 + values = values.T + if how == "ohlc": + out_shape = (ngroups, 4) + elif arity > 1: + raise NotImplementedError( + "arity of more than 1 is not supported for the 'how' argument" + ) + elif kind == "transform": + out_shape = values.shape else: - if axis > 0: - swapped = True - assert axis == 1, axis - values = values.T - if arity > 1: - raise NotImplementedError( - "arity of more than 1 is not supported for the 'how' argument" - ) - out_shape = (self.ngroups,) + values.shape[1:] + out_shape = (ngroups,) + values.shape[1:] func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric) @@ -652,13 +667,11 @@ def _cython_operation( codes, _, _ = self.group_info + result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) if kind == "aggregate": - result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) counts = np.zeros(self.ngroups, dtype=np.int64) result = self._aggregate(result, counts, values, codes, func, min_count) elif kind == "transform": - result = maybe_fill(np.empty(values.shape, dtype=out_dtype)) - # TODO: min_count result = self._transform( result, values, codes, func, is_datetimelike, **kwargs @@ -674,11 +687,7 @@ def _cython_operation( assert result.ndim != 2 result = result[counts > 0] - if vdim == 1 and arity == 1: - result = result[:, 0] - - if swapped: - result = result.swapaxes(0, axis) + result = result.T if how not in base.cython_cast_blocklist: # e.g. if we are int64 and need to restore to datetime64/timedelta64 diff --git a/pandas/core/missing.py b/pandas/core/missing.py index 41d7fed66469d..feaecec382704 100644 --- a/pandas/core/missing.py +++ b/pandas/core/missing.py @@ -861,7 +861,4 @@ def _rolling_window(a: np.ndarray, window: int): # https://stackoverflow.com/a/6811241 shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) strides = a.strides + (a.strides[-1],) - # error: Module has no attribute "stride_tricks" - return np.lib.stride_tricks.as_strided( # type: ignore[attr-defined] - a, shape=shape, strides=strides - ) + return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)