diff --git a/numba_dpex/core/caching.py b/numba_dpex/core/caching.py index ddd7972b17..990add9692 100644 --- a/numba_dpex/core/caching.py +++ b/numba_dpex/core/caching.py @@ -2,77 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 -import hashlib import sys from abc import ABCMeta, abstractmethod from numba.core.caching import CacheImpl, IndexDataCacheFile -from numba.core.serialize import dumps from numba_dpex import config -from numba_dpex.core.types import USMNdArray - - -def build_key( - argtypes, pyfunc, codegen, backend=None, device_type=None, exec_queue=None -): - """Constructs a key from python function, context, backend, the device - type and execution queue. - - Compute index key for the given argument types and codegen. It includes a - description of the OS, target architecture and hashes of the bytecode for - the function and, if the function has a __closure__, a hash of the - cell_contents.type - - Args: - argtypes : A tuple of numba types corresponding to the arguments to the - compiled function. - pyfunc : The Python function that is to be compiled and cached. - codegen (numba.core.codegen.Codegen): - The codegen object found from the target context. - backend (enum, optional): A 'backend_type' enum. - Defaults to None. - device_type (enum, optional): A 'device_type' enum. - Defaults to None. - exec_queue (dpctl._sycl_queue.SyclQueue', optional): A SYCL queue object. - - Returns: - tuple: A tuple of return type, argtpes, magic_tuple of codegen - and another tuple of hashcodes from bytecode and cell_contents. - """ - - codebytes = pyfunc.__code__.co_code - if pyfunc.__closure__ is not None: - try: - cvars = tuple([x.cell_contents for x in pyfunc.__closure__]) - # Note: cloudpickle serializes a function differently depending - # on how the process is launched; e.g. multiprocessing.Process - cvarbytes = dumps(cvars) - except: - cvarbytes = b"" # a temporary solution for function template - else: - cvarbytes = b"" - - argtylist = list(argtypes) - for i, argty in enumerate(argtylist): - if isinstance(argty, USMNdArray): - # Convert the USMNdArray to an abridged type that disregards the - # usm_type, device, queue, address space attributes. - argtylist[i] = (argty.ndim, argty.dtype, argty.layout) - - argtypes = tuple(argtylist) - - return ( - argtypes, - codegen.magic_tuple(), - backend, - device_type, - exec_queue, - ( - hashlib.sha256(codebytes).hexdigest(), - hashlib.sha256(cvarbytes).hexdigest(), - ), - ) class _CacheImpl(CacheImpl): @@ -475,8 +410,13 @@ def put(self, key, value): self._name, len(self._lookup), str(key) ) ) - self._lookup[key].value = value - self.get(key) + node = self._lookup[key] + node.value = value + + if node is not self._tail: + self._unlink_node(node) + self._append_tail(node) + return if key in self._evicted: diff --git a/numba_dpex/core/kernel_interface/dispatcher.py b/numba_dpex/core/kernel_interface/dispatcher.py index cc6b61bcaa..b1bbef7b63 100644 --- a/numba_dpex/core/kernel_interface/dispatcher.py +++ b/numba_dpex/core/kernel_interface/dispatcher.py @@ -14,7 +14,7 @@ from numba.core.types import void from numba_dpex import NdRange, Range, config -from numba_dpex.core.caching import LRUCache, NullCache, build_key +from numba_dpex.core.caching import LRUCache, NullCache from numba_dpex.core.descriptor import dpex_kernel_target from numba_dpex.core.exceptions import ( ComputeFollowsDataInferenceError, @@ -34,6 +34,11 @@ from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel from numba_dpex.core.types import USMNdArray +from numba_dpex.core.utils import ( + build_key, + create_func_hash, + strip_usm_metadata, +) def get_ordered_arg_access_types(pyfunc, access_types): @@ -85,6 +90,8 @@ def __init__( self._global_range = None self._local_range = None + self._func_hash = create_func_hash(pyfunc) + # caching related attributes if not config.ENABLE_CACHE: self._cache = NullCache() @@ -151,7 +158,7 @@ def cache(self): def cache_hits(self): return self._cache_hits - def _compile_and_cache(self, argtypes, cache): + def _compile_and_cache(self, argtypes, cache, key=None): """Helper function to compile the Python function or Numba FunctionIR object passed to a JitKernel and store it in an internal cache. """ @@ -171,11 +178,13 @@ def _compile_and_cache(self, argtypes, cache): device_driver_ir_module = kernel.device_driver_ir_module kernel_module_name = kernel.module_name - key = build_key( - tuple(argtypes), - self.pyfunc, - kernel.target_context.codegen(), - ) + if not key: + stripped_argtypes = strip_usm_metadata(argtypes) + codegen_magic_tuple = kernel.target_context.codegen().magic_tuple() + key = build_key( + stripped_argtypes, codegen_magic_tuple, self._func_hash + ) + cache.put(key, (device_driver_ir_module, kernel_module_name)) return device_driver_ir_module, kernel_module_name @@ -604,12 +613,12 @@ def __call__(self, *args): self.kernel_name, backend, JitKernel._supported_backends ) - # load the kernel from cache - key = build_key( - tuple(argtypes), - self.pyfunc, - dpex_kernel_target.target_context.codegen(), + # Generate key used for cache lookup + stripped_argtypes = strip_usm_metadata(argtypes) + codegen_magic_tuple = ( + dpex_kernel_target.target_context.codegen().magic_tuple() ) + key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash) # If the JitKernel was specialized then raise exception if argtypes # do not match one of the specialized versions. @@ -630,15 +639,11 @@ def __call__(self, *args): device_driver_ir_module, kernel_module_name, ) = self._compile_and_cache( - argtypes=argtypes, - cache=self._cache, + argtypes=argtypes, cache=self._cache, key=key ) kernel_bundle_key = build_key( - tuple(argtypes), - self.pyfunc, - dpex_kernel_target.target_context.codegen(), - exec_queue=exec_queue, + stripped_argtypes, codegen_magic_tuple, exec_queue, self._func_hash ) artifact = self._kernel_bundle_cache.get(kernel_bundle_key) diff --git a/numba_dpex/core/kernel_interface/func.py b/numba_dpex/core/kernel_interface/func.py index acd9b09bb1..8537a91742 100644 --- a/numba_dpex/core/kernel_interface/func.py +++ b/numba_dpex/core/kernel_interface/func.py @@ -5,14 +5,18 @@ """_summary_ """ - from numba.core import sigutils, types from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate from numba_dpex import config -from numba_dpex.core.caching import LRUCache, NullCache, build_key +from numba_dpex.core.caching import LRUCache, NullCache from numba_dpex.core.compiler import compile_with_dpex from numba_dpex.core.descriptor import dpex_kernel_target +from numba_dpex.core.utils import ( + build_key, + create_func_hash, + strip_usm_metadata, +) from numba_dpex.utils import npytypes_array_to_dpex_array @@ -91,6 +95,8 @@ def __init__(self, pyfunc, debug=False, enable_cache=True): self._debug = debug self._enable_cache = enable_cache + self._func_hash = create_func_hash(pyfunc) + if not config.ENABLE_CACHE: self._cache = NullCache() elif self._enable_cache: @@ -132,11 +138,14 @@ def compile(self, args): dpex_kernel_target.typing_context.resolve_argument_type(arg) for arg in args ] - key = build_key( - tuple(argtypes), - self._pyfunc, - dpex_kernel_target.target_context.codegen(), + + # Generate key used for cache lookup + stripped_argtypes = strip_usm_metadata(argtypes) + codegen_magic_tuple = ( + dpex_kernel_target.target_context.codegen().magic_tuple() ) + key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash) + cres = self._cache.get(key) if cres is None: self._cache_hits += 1 diff --git a/numba_dpex/core/utils/__init__.py b/numba_dpex/core/utils/__init__.py index 78bf969d57..e736b48ce6 100644 --- a/numba_dpex/core/utils/__init__.py +++ b/numba_dpex/core/utils/__init__.py @@ -2,9 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from .caching_utils import build_key, create_func_hash, strip_usm_metadata from .suai_helper import SyclUSMArrayInterface, get_info_from_suai __all__ = [ "get_info_from_suai", "SyclUSMArrayInterface", + "create_func_hash", + "strip_usm_metadata", + "build_key", ] diff --git a/numba_dpex/core/utils/caching_utils.py b/numba_dpex/core/utils/caching_utils.py new file mode 100644 index 0000000000..b48f460b9b --- /dev/null +++ b/numba_dpex/core/utils/caching_utils.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import hashlib + +from numba.core.serialize import dumps + +from numba_dpex.core.types import USMNdArray + + +def build_key(*args): + """Constructs key from variable list of args + + Args: + *args: List of components to construct key + Return: + Tuple of args + """ + return tuple(args) + + +def create_func_hash(pyfunc): + """Creates a tuple of sha256 hashes out of code and + variable bytes extracted from the compiled funtion. + + Args: + pyfunc: Python function object + Return: + Tuple of hashes of code and variable bytes + """ + codebytes = pyfunc.__code__.co_code + if pyfunc.__closure__ is not None: + try: + cvars = tuple([x.cell_contents for x in pyfunc.__closure__]) + # Note: cloudpickle serializes a function differently depending + # on how the process is launched; e.g. multiprocessing.Process + cvarbytes = dumps(cvars) + except: + cvarbytes = b"" # a temporary solution for function template + else: + cvarbytes = b"" + + return ( + hashlib.sha256(codebytes).hexdigest(), + hashlib.sha256(cvarbytes).hexdigest(), + ) + + +def strip_usm_metadata(argtypes): + """Convert the USMNdArray to an abridged type that disregards the + usm_type, device, queue, address space attributes. + + Args: + argtypes: List of types + + Return: + Tuple of types after removing USM metadata from USMNdArray type + """ + + stripped_argtypes = [] + for argty in argtypes: + if isinstance(argty, USMNdArray): + stripped_argtypes.append((argty.ndim, argty.dtype, argty.layout)) + else: + stripped_argtypes.append(argty) + + return tuple(stripped_argtypes)