diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index ab3316e7fca4c..ea9b06a58be92 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -30,6 +30,7 @@ Other enhancements ^^^^^^^^^^^^^^^^^^ - :class:`pandas.api.typing.FrozenList` is available for typing the outputs of :attr:`MultiIndex.names`, :attr:`MultiIndex.codes` and :attr:`MultiIndex.levels` (:issue:`58237`) - :class:`pandas.api.typing.SASReader` is available for typing the output of :func:`read_sas` (:issue:`55689`) +- :meth:`DataFrame.apply` accepts Numba as an engine by passing the JIT decorator directly, e.g. ``df.apply(func, engine=numba.jit)`` (:issue:`61458`) - :meth:`pandas.api.interchange.from_dataframe` now uses the `PyCapsule Interface `_ if available, only falling back to the Dataframe Interchange Protocol if that fails (:issue:`60739`) - Added :meth:`.Styler.to_typst` to write Styler objects to file, buffer or string in Typst format (:issue:`57617`) - Added missing :meth:`pandas.Series.info` to API reference (:issue:`60926`) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 2c96f1ef020ac..760fd111f21ce 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -178,6 +178,60 @@ def apply( """ +class NumbaExecutionEngine(BaseExecutionEngine): + """ + Numba-based execution engine for pandas apply and map operations. + """ + + @staticmethod + def map( + data: np.ndarray | Series | DataFrame, + func, + args: tuple, + kwargs: dict, + decorator: Callable | None, + skip_na: bool, + ): + """ + Elementwise map for the Numba engine. Currently not supported. + """ + raise NotImplementedError("Numba map is not implemented yet.") + + @staticmethod + def apply( + data: np.ndarray | Series | DataFrame, + func, + args: tuple, + kwargs: dict, + decorator: Callable, + axis: int | str, + ): + """ + Apply `func` along the given axis using Numba. + """ + engine_kwargs: dict[str, bool] | None = ( + decorator if isinstance(decorator, dict) else None + ) + + looper_args, looper_kwargs = prepare_function_arguments( + func, + args, + kwargs, + num_required_args=1, + ) + # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has + # incompatible type "Callable[..., Any] | str | list[Callable + # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | + # list[Callable[..., Any] | str]]"; expected "Hashable" + nb_looper = generate_apply_looper( + func, + **get_jit_arguments(engine_kwargs), + ) + result = nb_looper(data, axis, *looper_args) + # If we made the result 2-D, squeeze it back to 1-D + return np.squeeze(result) + + def frame_apply( obj: DataFrame, func: AggFuncType, @@ -1094,23 +1148,31 @@ def wrapper(*args, **kwargs): return wrapper if engine == "numba": - args, kwargs = prepare_function_arguments( - self.func, # type: ignore[arg-type] - self.args, - self.kwargs, - num_required_args=1, - ) - # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has - # incompatible type "Callable[..., Any] | str | list[Callable - # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | - # list[Callable[..., Any] | str]]"; expected "Hashable" - nb_looper = generate_apply_looper( - self.func, # type: ignore[arg-type] - **get_jit_arguments(engine_kwargs), - ) - result = nb_looper(self.values, self.axis, *args) - # If we made the result 2-D, squeeze it back to 1-D - result = np.squeeze(result) + try: + import numba + + if not hasattr(numba.jit, "__pandas_udf__"): + numba.jit.__pandas_udf__ = NumbaExecutionEngine + result = numba.jit.__pandas_udf__.apply( + self.values, + self.func, + self.args, + self.kwargs, + engine_kwargs, + self.axis, + ) + else: + raise ImportError + except ImportError: + engine_obj = NumbaExecutionEngine() + result = engine_obj.apply( + self.values, + self.func, + self.args, + self.kwargs, + engine_kwargs, + self.axis, + ) else: result = np.apply_along_axis( wrap_function(self.func),