diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index aad43e4f96b81..9c695148a75c0 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2897,16 +2897,15 @@ def _get_cythonized_result( ids, _, ngroups = grouper.group_info output: dict[base.OutputKey, np.ndarray] = {} - base_func = getattr(libgroupby, how) - - error_msg = "" - for idx, obj in enumerate(self._iterate_slices()): - name = obj.name - values = obj._values - if numeric_only and not is_numeric_dtype(values.dtype): - continue + base_func = getattr(libgroupby, how) + base_func = partial(base_func, labels=ids) + if needs_ngroups: + base_func = partial(base_func, ngroups=ngroups) + if min_count is not None: + base_func = partial(base_func, min_count=min_count) + def blk_func(values: ArrayLike) -> ArrayLike: if aggregate: result_sz = ngroups else: @@ -2915,54 +2914,31 @@ def _get_cythonized_result( result = np.zeros(result_sz, dtype=cython_dtype) if needs_2d: result = result.reshape((-1, 1)) - func = partial(base_func, result) + func = partial(base_func, out=result) inferences = None if needs_counts: counts = np.zeros(self.ngroups, dtype=np.int64) - func = partial(func, counts) + func = partial(func, counts=counts) if needs_values: vals = values if pre_processing: - try: - vals, inferences = pre_processing(vals) - except TypeError as err: - error_msg = str(err) - howstr = how.replace("group_", "") - warnings.warn( - "Dropping invalid columns in " - f"{type(self).__name__}.{howstr} is deprecated. " - "In a future version, a TypeError will be raised. " - f"Before calling .{howstr}, select only columns which " - "should be valid for the function.", - FutureWarning, - stacklevel=3, - ) - continue + vals, inferences = pre_processing(vals) + vals = vals.astype(cython_dtype, copy=False) if needs_2d: vals = vals.reshape((-1, 1)) - func = partial(func, vals) - - func = partial(func, ids) - - if min_count is not None: - func = partial(func, min_count) + func = partial(func, values=vals) if needs_mask: mask = isna(values).view(np.uint8) - func = partial(func, mask) - - if needs_ngroups: - func = partial(func, ngroups) + func = partial(func, mask=mask) if needs_nullable: is_nullable = isinstance(values, BaseMaskedArray) func = partial(func, nullable=is_nullable) - if post_processing: - post_processing = partial(post_processing, nullable=is_nullable) func(**kwargs) # Call func to modify indexer values in place @@ -2973,9 +2949,38 @@ def _get_cythonized_result( result = algorithms.take_nd(values, result) if post_processing: - result = post_processing(result, inferences) + pp_kwargs = {} + if needs_nullable: + pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray) - key = base.OutputKey(label=name, position=idx) + result = post_processing(result, inferences, **pp_kwargs) + + return result + + error_msg = "" + for idx, obj in enumerate(self._iterate_slices()): + values = obj._values + + if numeric_only and not is_numeric_dtype(values.dtype): + continue + + try: + result = blk_func(values) + except TypeError as err: + error_msg = str(err) + howstr = how.replace("group_", "") + warnings.warn( + "Dropping invalid columns in " + f"{type(self).__name__}.{howstr} is deprecated. " + "In a future version, a TypeError will be raised. " + f"Before calling .{howstr}, select only columns which " + "should be valid for the function.", + FutureWarning, + stacklevel=3, + ) + continue + + key = base.OutputKey(label=obj.name, position=idx) output[key] = result # error_msg is "" on an frame/series with no rows or columns