diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py new file mode 100644 index 0000000000000..e4debab2c22ee --- /dev/null +++ b/pandas/core/util/numba_.py @@ -0,0 +1,58 @@ +"""Common utilities for Numba operations""" +import types +from typing import Callable, Dict, Optional + +import numpy as np + +from pandas.compat._optional import import_optional_dependency + + +def check_kwargs_and_nopython( + kwargs: Optional[Dict] = None, nopython: Optional[bool] = None +): + if kwargs and nopython: + raise ValueError( + "numba does not support kwargs with nopython=True: " + "https://github.com/numba/numba/issues/2916" + ) + + +def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None): + """ + Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. + """ + if engine_kwargs is None: + engine_kwargs = {} + + nopython = engine_kwargs.get("nopython", True) + nogil = engine_kwargs.get("nogil", False) + parallel = engine_kwargs.get("parallel", False) + return nopython, nogil, parallel + + +def jit_user_function(func: Callable, nopython: bool, nogil: bool, parallel: bool): + """ + JIT the user's function given the configurable arguments. + """ + numba = import_optional_dependency("numba") + + if isinstance(func, numba.targets.registry.CPUDispatcher): + # Don't jit a user passed jitted function + numba_func = func + else: + + @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel) + def numba_func(data, *_args): + if getattr(np, func.__name__, False) is func or isinstance( + func, types.BuiltinFunctionType + ): + jf = func + else: + jf = numba.jit(func, nopython=nopython, nogil=nogil) + + def impl(data, *_args): + return jf(data, *_args) + + return impl + + return numba_func diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index d6e8194c861fa..5d35ec7457ab0 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -1,4 +1,3 @@ -import types from typing import Any, Callable, Dict, Optional, Tuple import numpy as np @@ -6,35 +5,49 @@ from pandas._typing import Scalar from pandas.compat._optional import import_optional_dependency +from pandas.core.util.numba_ import ( + check_kwargs_and_nopython, + get_jit_arguments, + jit_user_function, +) -def make_rolling_apply( - func: Callable[..., Scalar], + +def generate_numba_apply_func( args: Tuple, - nogil: bool, - parallel: bool, - nopython: bool, + kwargs: Dict[str, Any], + func: Callable[..., Scalar], + engine_kwargs: Optional[Dict[str, bool]], ): """ - Creates a JITted rolling apply function with a JITted version of - the user's function. + Generate a numba jitted apply function specified by values from engine_kwargs. + + 1. jit the user's function + 2. Return a rolling apply function with the jitted function inline + + Configurations specified in engine_kwargs apply to both the user's + function _AND_ the rolling apply function. Parameters ---------- - func : function - function to be applied to each window and will be JITed args : tuple *args to be passed into the function - nogil : bool - nogil parameter from engine_kwargs for numba.jit - parallel : bool - parallel parameter from engine_kwargs for numba.jit - nopython : bool - nopython parameter from engine_kwargs for numba.jit + kwargs : dict + **kwargs to be passed into the function + func : function + function to be applied to each window and will be JITed + engine_kwargs : dict + dictionary of arguments to be passed into numba.jit Returns ------- Numba function """ + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + + check_kwargs_and_nopython(kwargs, nopython) + + numba_func = jit_user_function(func, nopython, nogil, parallel) + numba = import_optional_dependency("numba") if parallel: @@ -42,25 +55,6 @@ def make_rolling_apply( else: loop_range = range - if isinstance(func, numba.targets.registry.CPUDispatcher): - # Don't jit a user passed jitted function - numba_func = func - else: - - @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel) - def numba_func(window, *_args): - if getattr(np, func.__name__, False) is func or isinstance( - func, types.BuiltinFunctionType - ): - jf = func - else: - jf = numba.jit(func, nopython=nopython, nogil=nogil) - - def impl(window, *_args): - return jf(window, *_args) - - return impl - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def roll_apply( values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int, @@ -78,49 +72,3 @@ def roll_apply( return result return roll_apply - - -def generate_numba_apply_func( - args: Tuple, - kwargs: Dict[str, Any], - func: Callable[..., Scalar], - engine_kwargs: Optional[Dict[str, bool]], -): - """ - Generate a numba jitted apply function specified by values from engine_kwargs. - - 1. jit the user's function - 2. Return a rolling apply function with the jitted function inline - - Configurations specified in engine_kwargs apply to both the user's - function _AND_ the rolling apply function. - - Parameters - ---------- - args : tuple - *args to be passed into the function - kwargs : dict - **kwargs to be passed into the function - func : function - function to be applied to each window and will be JITed - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit - - Returns - ------- - Numba function - """ - if engine_kwargs is None: - engine_kwargs = {} - - nopython = engine_kwargs.get("nopython", True) - nogil = engine_kwargs.get("nogil", False) - parallel = engine_kwargs.get("parallel", False) - - if kwargs and nopython: - raise ValueError( - "numba does not support kwargs with nopython=True: " - "https://github.com/numba/numba/issues/2916" - ) - - return make_rolling_apply(func, args, nogil, parallel, nopython)