From 7b8f5aea25cb52812af0e22ba38e1ac8eee391ea Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Fri, 13 Mar 2020 22:07:02 -0700 Subject: [PATCH 1/4] Centralize numba checks and have groupby.transform accept engine and engine_kwargs --- pandas/core/groupby/generic.py | 16 +++++++--- pandas/core/numba_.py | 58 ++++++++++++++++++++++++++++++++++ pandas/core/window/numba_.py | 42 ++++++------------------ 3 files changed, 79 insertions(+), 37 deletions(-) create mode 100644 pandas/core/numba_.py diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 4102b8527b6aa..7e6d4b18382e3 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -461,7 +461,7 @@ def _aggregate_named(self, func, *args, **kwargs): @Substitution(klass="Series", selected="A.") @Appender(_transform_template) - def transform(self, func, *args, **kwargs): + def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): func = self._get_cython_func(func) or func if not isinstance(func, str): @@ -480,7 +480,9 @@ def transform(self, func, *args, **kwargs): result = getattr(self, func)(*args, **kwargs) return self._transform_fast(result, func) - def _transform_general(self, func, *args, **kwargs): + def _transform_general( + self, func, engine="cython", engine_kwargs=None, *args, **kwargs + ): """ Transform with a non-str `func`. """ @@ -1355,7 +1357,9 @@ def first_not_none(values): # Handle cases like BinGrouper return self._concat_objects(keys, values, not_indexed_same=not_indexed_same) - def _transform_general(self, func, *args, **kwargs): + def _transform_general( + self, func, engine="cython", engine_kwargs=None, *args, **kwargs + ): from pandas.core.reshape.concat import concat applied = [] @@ -1411,13 +1415,15 @@ def _transform_general(self, func, *args, **kwargs): @Substitution(klass="DataFrame", selected="") @Appender(_transform_template) - def transform(self, func, *args, **kwargs): + def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): # optimized transforms func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) + return self._transform_general( + func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs + ) elif func not in base.transform_kernel_whitelist: msg = f"'{func}' is not a valid function name for transform(name)" diff --git a/pandas/core/numba_.py b/pandas/core/numba_.py new file mode 100644 index 0000000000000..e4debab2c22ee --- /dev/null +++ b/pandas/core/numba_.py @@ -0,0 +1,58 @@ +"""Common utilities for Numba operations""" +import types +from typing import Callable, Dict, Optional + +import numpy as np + +from pandas.compat._optional import import_optional_dependency + + +def check_kwargs_and_nopython( + kwargs: Optional[Dict] = None, nopython: Optional[bool] = None +): + if kwargs and nopython: + raise ValueError( + "numba does not support kwargs with nopython=True: " + "https://github.com/numba/numba/issues/2916" + ) + + +def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None): + """ + Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. + """ + if engine_kwargs is None: + engine_kwargs = {} + + nopython = engine_kwargs.get("nopython", True) + nogil = engine_kwargs.get("nogil", False) + parallel = engine_kwargs.get("parallel", False) + return nopython, nogil, parallel + + +def jit_user_function(func: Callable, nopython: bool, nogil: bool, parallel: bool): + """ + JIT the user's function given the configurable arguments. + """ + numba = import_optional_dependency("numba") + + if isinstance(func, numba.targets.registry.CPUDispatcher): + # Don't jit a user passed jitted function + numba_func = func + else: + + @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel) + def numba_func(data, *_args): + if getattr(np, func.__name__, False) is func or isinstance( + func, types.BuiltinFunctionType + ): + jf = func + else: + jf = numba.jit(func, nopython=nopython, nogil=nogil) + + def impl(data, *_args): + return jf(data, *_args) + + return impl + + return numba_func diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index d6e8194c861fa..e37a1c755e662 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -1,10 +1,14 @@ -import types from typing import Any, Callable, Dict, Optional, Tuple import numpy as np from pandas._typing import Scalar from pandas.compat._optional import import_optional_dependency +from pandas.core.numba_ import ( + check_kwargs_and_nopython, + get_jit_arguments, + jit_user_function, +) def make_rolling_apply( @@ -37,30 +41,13 @@ def make_rolling_apply( """ numba = import_optional_dependency("numba") + numba_func = jit_user_function(func, nopython, nogil, parallel) + if parallel: loop_range = numba.prange else: loop_range = range - if isinstance(func, numba.targets.registry.CPUDispatcher): - # Don't jit a user passed jitted function - numba_func = func - else: - - @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel) - def numba_func(window, *_args): - if getattr(np, func.__name__, False) is func or isinstance( - func, types.BuiltinFunctionType - ): - jf = func - else: - jf = numba.jit(func, nopython=nopython, nogil=nogil) - - def impl(window, *_args): - return jf(window, *_args) - - return impl - @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def roll_apply( values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int, @@ -110,17 +97,8 @@ def generate_numba_apply_func( ------- Numba function """ - if engine_kwargs is None: - engine_kwargs = {} - - nopython = engine_kwargs.get("nopython", True) - nogil = engine_kwargs.get("nogil", False) - parallel = engine_kwargs.get("parallel", False) - - if kwargs and nopython: - raise ValueError( - "numba does not support kwargs with nopython=True: " - "https://github.com/numba/numba/issues/2916" - ) + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + + check_kwargs_and_nopython(kwargs, nopython) return make_rolling_apply(func, args, nogil, parallel, nopython) From 7ffb61e92297be9d68c4b49fa211055c88d4c882 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 17 Mar 2020 00:30:33 -0700 Subject: [PATCH 2/4] Remove engine kwarg for now --- pandas/core/groupby/generic.py | 16 +++----- pandas/core/window/numba_.py | 75 ++++++++++------------------------ 2 files changed, 27 insertions(+), 64 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 7e6d4b18382e3..4102b8527b6aa 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -461,7 +461,7 @@ def _aggregate_named(self, func, *args, **kwargs): @Substitution(klass="Series", selected="A.") @Appender(_transform_template) - def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): + def transform(self, func, *args, **kwargs): func = self._get_cython_func(func) or func if not isinstance(func, str): @@ -480,9 +480,7 @@ def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): result = getattr(self, func)(*args, **kwargs) return self._transform_fast(result, func) - def _transform_general( - self, func, engine="cython", engine_kwargs=None, *args, **kwargs - ): + def _transform_general(self, func, *args, **kwargs): """ Transform with a non-str `func`. """ @@ -1357,9 +1355,7 @@ def first_not_none(values): # Handle cases like BinGrouper return self._concat_objects(keys, values, not_indexed_same=not_indexed_same) - def _transform_general( - self, func, engine="cython", engine_kwargs=None, *args, **kwargs - ): + def _transform_general(self, func, *args, **kwargs): from pandas.core.reshape.concat import concat applied = [] @@ -1415,15 +1411,13 @@ def _transform_general( @Substitution(klass="DataFrame", selected="") @Appender(_transform_template) - def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): + def transform(self, func, *args, **kwargs): # optimized transforms func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general( - func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs - ) + return self._transform_general(func, *args, **kwargs) elif func not in base.transform_kernel_whitelist: msg = f"'{func}' is not a valid function name for transform(name)" diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index e37a1c755e662..f513a0f894be1 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -11,38 +11,44 @@ ) -def make_rolling_apply( - func: Callable[..., Scalar], +def generate_numba_apply_func( args: Tuple, - nogil: bool, - parallel: bool, - nopython: bool, + kwargs: Dict[str, Any], + func: Callable[..., Scalar], + engine_kwargs: Optional[Dict[str, bool]], ): """ - Creates a JITted rolling apply function with a JITted version of - the user's function. + Generate a numba jitted apply function specified by values from engine_kwargs. + + 1. jit the user's function + 2. Return a rolling apply function with the jitted function inline + + Configurations specified in engine_kwargs apply to both the user's + function _AND_ the rolling apply function. Parameters ---------- - func : function - function to be applied to each window and will be JITed args : tuple *args to be passed into the function - nogil : bool - nogil parameter from engine_kwargs for numba.jit - parallel : bool - parallel parameter from engine_kwargs for numba.jit - nopython : bool - nopython parameter from engine_kwargs for numba.jit + kwargs : dict + **kwargs to be passed into the function + func : function + function to be applied to each window and will be JITed + engine_kwargs : dict + dictionary of arguments to be passed into numba.jit Returns ------- Numba function """ - numba = import_optional_dependency("numba") + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + + check_kwargs_and_nopython(kwargs, nopython) numba_func = jit_user_function(func, nopython, nogil, parallel) + numba = import_optional_dependency("numba") + if parallel: loop_range = numba.prange else: @@ -65,40 +71,3 @@ def roll_apply( return result return roll_apply - - -def generate_numba_apply_func( - args: Tuple, - kwargs: Dict[str, Any], - func: Callable[..., Scalar], - engine_kwargs: Optional[Dict[str, bool]], -): - """ - Generate a numba jitted apply function specified by values from engine_kwargs. - - 1. jit the user's function - 2. Return a rolling apply function with the jitted function inline - - Configurations specified in engine_kwargs apply to both the user's - function _AND_ the rolling apply function. - - Parameters - ---------- - args : tuple - *args to be passed into the function - kwargs : dict - **kwargs to be passed into the function - func : function - function to be applied to each window and will be JITed - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit - - Returns - ------- - Numba function - """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - - check_kwargs_and_nopython(kwargs, nopython) - - return make_rolling_apply(func, args, nogil, parallel, nopython) From b722731813f8339ebbec36808b515be9a68ff23c Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 17 Mar 2020 09:17:22 -0700 Subject: [PATCH 3/4] isort --- pandas/core/window/numba_.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index f513a0f894be1..4c060b6c50446 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -4,6 +4,7 @@ from pandas._typing import Scalar from pandas.compat._optional import import_optional_dependency + from pandas.core.numba_ import ( check_kwargs_and_nopython, get_jit_arguments, From d47a90d6e6c53bb1849b0782b821a4c13ee84663 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 17 Mar 2020 20:00:45 -0700 Subject: [PATCH 4/4] Move numba_.py to core/util --- pandas/core/{ => util}/numba_.py | 0 pandas/core/window/numba_.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename pandas/core/{ => util}/numba_.py (100%) diff --git a/pandas/core/numba_.py b/pandas/core/util/numba_.py similarity index 100% rename from pandas/core/numba_.py rename to pandas/core/util/numba_.py diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index 4c060b6c50446..5d35ec7457ab0 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -5,7 +5,7 @@ from pandas._typing import Scalar from pandas.compat._optional import import_optional_dependency -from pandas.core.numba_ import ( +from pandas.core.util.numba_ import ( check_kwargs_and_nopython, get_jit_arguments, jit_user_function,