diff --git a/numba_dpex/config.py b/numba_dpex/config.py index 7fc121d059..146cb35301 100644 --- a/numba_dpex/config.py +++ b/numba_dpex/config.py @@ -99,7 +99,7 @@ def __getattr__(name): ENABLE_CACHE = _readenv("NUMBA_DPEX_ENABLE_CACHE", int, 1) # Capacity of the cache, execute it like: # NUMBA_DPEX_CACHE_SIZE=20 python -CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 10) +CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 128) TESTING_SKIP_NO_DPNP = _readenv("NUMBA_DPEX_TESTING_SKIP_NO_DPNP", int, 0) TESTING_SKIP_NO_DEBUGGING = _readenv( diff --git a/numba_dpex/core/kernel_interface/dispatcher.py b/numba_dpex/core/kernel_interface/dispatcher.py index 50480c2c33..feb6ba6bbc 100644 --- a/numba_dpex/core/kernel_interface/dispatcher.py +++ b/numba_dpex/core/kernel_interface/dispatcher.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable from inspect import signature from warnings import warn @@ -32,6 +33,7 @@ ) from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel +from numba_dpex.core.kernel_interface.utils import NdRange, Range from numba_dpex.core.types import USMNdArray @@ -468,51 +470,87 @@ def __getitem__(self, args): global_range and local_range attributes initialized. """ - if isinstance(args, int): - self._global_range = [args] - self._local_range = None - elif isinstance(args, tuple) or isinstance(args, list): - if len(args) == 1 and all(isinstance(v, int) for v in args): - self._global_range = list(args) - self._local_range = None - elif len(args) == 2: - gr = args[0] - lr = args[1] - if isinstance(gr, int): - self._global_range = [gr] - elif len(gr) != 0 and all(isinstance(v, int) for v in gr): - self._global_range = list(gr) - else: - raise IllegalRangeValueError(kernel_name=self.kernel_name) + if isinstance(args, Range): + # we need inversions, see github issue #889 + self._global_range = list(args)[::-1] + elif isinstance(args, NdRange): + # we need inversions, see github issue #889 + self._global_range = list(args.global_range)[::-1] + self._local_range = list(args.local_range)[::-1] + else: + if ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], int) + and isinstance(args[1], int) + ): + warn( + "Ambiguous kernel launch paramters. If your data have " + + "dimensions > 1, include a default/empty local_range:\n" + + " [(X,Y), numba_dpex.DEFAULT_LOCAL_RANGE]()\n" + + "otherwise your code might produce erroneous results.", + DeprecationWarning, + stacklevel=2, + ) + self._global_range = [args[0]] + self._local_range = [args[1]] + return self + + warn( + "The current syntax for specification of kernel lauch " + + "parameters is deprecated. Users should set the kernel " + + "parameters through Range/NdRange classes.\n" + + "Example:\n" + + " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n" + + " # for global range only\n" + + " [Range(X,Y)]()\n" + + " # or,\n" + + " # for both global and local ranges\n" + + " [NdRange((X,Y), (P,Q))]()", + DeprecationWarning, + stacklevel=2, + ) - if isinstance(lr, int): - self._local_range = [lr] - elif isinstance(lr, list) and len(lr) == 0: - # deprecation warning + args = [args] if not isinstance(args, Iterable) else args + nargs = len(args) + + # Check if the kernel enquing arguments are sane + if nargs < 1 or nargs > 2: + raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name) + + g_range = ( + [args[0]] if not isinstance(args[0], Iterable) else args[0] + ) + # If the optional local size argument is provided + l_range = None + if nargs == 2: + if args[1] != []: + l_range = ( + [args[1]] + if not isinstance(args[1], Iterable) + else args[1] + ) + else: warn( - "Specifying the local range as an empty list " - "(DEFAULT_LOCAL_SIZE) is deprecated. The kernel will " - "be executed as a basic data-parallel kernel over the " - "global range. Specify a valid local range to execute " - "the kernel as an ND-range kernel.", + "Empty local_range calls are deprecated. Please use Range/NdRange " + + "to specify the kernel launch parameters:\n" + + "Example:\n" + + " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n" + + " # for global range only\n" + + " [Range(X,Y)]()\n" + + " # or,\n" + + " # for both global and local ranges\n" + + " [NdRange((X,Y), (P,Q))]()", DeprecationWarning, stacklevel=2, ) - self._local_range = None - elif len(lr) != 0 and all(isinstance(v, int) for v in lr): - self._local_range = list(lr) - else: - raise IllegalRangeValueError(kernel_name=self.kernel_name) - else: - raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name) - else: - raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name) - # FIXME:[::-1] is done as OpenCL and SYCl have different orders when - # it comes to specifying dimensions. - self._global_range = list(self._global_range)[::-1] - if self._local_range: - self._local_range = list(self._local_range)[::-1] + if len(g_range) < 1: + raise IllegalRangeValueError(kernel_name=self.kernel_name) + + # we need inversions, see github issue #889 + self._global_range = list(g_range)[::-1] + self._local_range = list(l_range)[::-1] if l_range else None return self diff --git a/numba_dpex/core/kernel_interface/utils.py b/numba_dpex/core/kernel_interface/utils.py new file mode 100644 index 0000000000..55431cc2bb --- /dev/null +++ b/numba_dpex/core/kernel_interface/utils.py @@ -0,0 +1,218 @@ +from collections.abc import Iterable + + +class Range(tuple): + """A data structure to encapsulate a single kernel lauch parameter. + + The range is an abstraction that describes the number of elements + in each dimension of buffers and index spaces. It can contain + 1, 2, or 3 numbers, dependending on the dimensionality of the + object it describes. + + This is just a wrapper class on top of a 3-tuple. The kernel launch + parameter is consisted of three int's. This class basically mimics + the behavior of `sycl::range`. + """ + + def __new__(cls, dim0, dim1=None, dim2=None): + """Constructs a 1, 2, or 3 dimensional range. + + Args: + dim0 (int): The range of the first dimension. + dim1 (int, optional): The range of second dimension. + Defaults to None. + dim2 (int, optional): The range of the third dimension. + Defaults to None. + + Raises: + TypeError: If dim0 is not an int. + TypeError: If dim1 is not an int. + TypeError: If dim2 is not an int. + """ + if not isinstance(dim0, int): + raise TypeError("dim0 of a Range must be an int.") + _values = [dim0] + if dim1: + if not isinstance(dim1, int): + raise TypeError("dim1 of a Range must be an int.") + _values.append(dim1) + if dim2: + if not isinstance(dim2, int): + raise TypeError("dim2 of a Range must be an int.") + _values.append(dim2) + return super(Range, cls).__new__(cls, tuple(_values)) + + def get(self, index): + """Returns the range of a single dimension. + + Args: + index (int): The index of the dimension, i.e. [0,2] + + Returns: + int: The range of the dimension indexed by `index`. + """ + return self[index] + + def size(self): + """Returns the size of a range. + + Returns the size of a range by multiplying + the range of the individual dimensions. + + Returns: + int: The size of a range. + """ + n = len(self) + if n > 2: + return self[0] * self[1] * self[2] + elif n > 1: + return self[0] * self[1] + else: + return self[0] + + +class NdRange: + """A class to encapsulate all kernel launch parameters. + + The NdRange defines the index space for a work group as well as + the global index space. It is passed to parallel_for to execute + a kernel on a set of work items. + + This class basically contains two Range object, one for the global_range + and the other for the local_range. The global_range parameter contains + the global index space and the local_range parameter contains the index + space of a work group. This class mimics the behavior of `sycl::nd_range` + class. + """ + + def __init__(self, global_size, local_size): + """Constructor for NdRange class. + + Args: + global_size (Range or tuple of int's): The values for + the global_range. + local_size (Range or tuple of int's, optional): The values for + the local_range. Defaults to None. + """ + if isinstance(global_size, Range): + self._global_range = global_size + elif isinstance(global_size, Iterable): + self._global_range = Range(*global_size) + else: + TypeError("Unknwon argument type for NdRange global_size.") + + if isinstance(local_size, Range): + self._local_range = local_size + elif isinstance(local_size, Iterable): + self._local_range = Range(*local_size) + else: + TypeError("Unknwon argument type for NdRange local_size.") + + @property + def global_range(self): + """Accessor for global_range. + + Returns: + Range: The `global_range` `Range` object. + """ + return self._global_range + + @property + def local_range(self): + """Accessor for local_range. + + Returns: + Range: The `local_range` `Range` object. + """ + return self._local_range + + def get_global_range(self): + """Returns a Range defining the index space. + + Returns: + Range: A `Range` object defining the index space. + """ + return self._global_range + + def get_local_range(self): + """Returns a Range defining the index space of a work group. + + Returns: + Range: A `Range` object to specify index space of a work group. + """ + return self._local_range + + def __str__(self): + """str() function for NdRange class. + + Returns: + str: str representation for NdRange class. + """ + return ( + "(" + str(self._global_range) + ", " + str(self._local_range) + ")" + ) + + def __repr__(self): + """repr() function for NdRange class. + + Returns: + str: str representation for NdRange class. + """ + return self.__str__() + + +if __name__ == "__main__": + r1 = Range(1) + print("r1 =", r1) + + r2 = Range(1, 2) + print("r2 =", r2) + + r3 = Range(1, 2, 3) + print("r3 =", r3, ", len(r3) =", len(r3)) + + r3 = Range(*(1, 2, 3)) + print("r3 =", r3, ", len(r3) =", len(r3)) + + r3 = Range(*[1, 2, 3]) + print("r3 =", r3, ", len(r3) =", len(r3)) + + print("r1.get(0) =", r1.get(0)) + try: + print("r2.get(2) =", r2.get(2)) + except Exception as e: + print(e) + + print("r3.get(0) =", r3.get(0)) + print("r3.get(1) =", r3.get(1)) + + print("r1[0] =", r1[0]) + try: + print("r2[2] =", r2[2]) + except Exception as e: + print(e) + + print("r3[0] =", r3[0]) + print("r3[1] =", r3[1]) + + try: + r4 = Range(1, 2, 3, 4) + except Exception as e: + print(e) + + try: + r5 = Range(*(1, 2, 3, 4)) + except Exception as e: + print(e) + + ndr1 = NdRange(Range(1, 2)) + print("ndr1 =", ndr1) + + ndr2 = NdRange(Range(1, 2), Range(1, 1, 1)) + print("ndr2 =", ndr2) + + ndr3 = NdRange((1, 2)) + print("ndr3 =", ndr3) + + ndr4 = NdRange((1, 2), (1, 1, 1)) + print("ndr4 =", ndr4) diff --git a/numba_dpex/examples/debug/dpex_func.py b/numba_dpex/examples/debug/dpex_func.py index bc095c65af..4acc97d763 100644 --- a/numba_dpex/examples/debug/dpex_func.py +++ b/numba_dpex/examples/debug/dpex_func.py @@ -6,6 +6,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range @dpex.func(debug=True) @@ -24,7 +25,7 @@ def driver(a, b, c, global_size): print("a = ", a) print("b = ", b) print("c = ", c) - kernel_sum[global_size, dpex.DEFAULT_LOCAL_SIZE](a, b, c) + kernel_sum[Range(global_size)](a, b, c) print("a + b = ", c) diff --git a/numba_dpex/examples/debug/side-by-side-2.py b/numba_dpex/examples/debug/side-by-side-2.py index 119a6fd7dc..4c9797a856 100644 --- a/numba_dpex/examples/debug/side-by-side-2.py +++ b/numba_dpex/examples/debug/side-by-side-2.py @@ -9,6 +9,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range def common_loop_body(i, a, b): @@ -50,7 +51,7 @@ def numba_func_driver(a, b, c): def dpex_func_driver(a, b, c): device = dpctl.select_default_device() with dpctl.device_context(device): - kernel[len(c), dpex.DEFAULT_LOCAL_SIZE](a, b, c) + kernel[Range(len(c))](a, b, c) @dpex.kernel(debug=True) diff --git a/numba_dpex/examples/debug/side-by-side.py b/numba_dpex/examples/debug/side-by-side.py index d915c1c886..9f7c0db66a 100644 --- a/numba_dpex/examples/debug/side-by-side.py +++ b/numba_dpex/examples/debug/side-by-side.py @@ -9,6 +9,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range def common_loop_body(param_a, param_b): @@ -48,7 +49,7 @@ def numba_func_driver(a, b, c): def dpex_func_driver(a, b, c): device = dpctl.select_default_device() with dpctl.device_context(device): - kernel[len(c), dpex.DEFAULT_LOCAL_SIZE](a, b, c) + kernel[Range(len(c))](a, b, c) @dpex.kernel(debug=True) diff --git a/numba_dpex/examples/debug/simple_dpex_func.py b/numba_dpex/examples/debug/simple_dpex_func.py index fbd57349d8..976430dd11 100644 --- a/numba_dpex/examples/debug/simple_dpex_func.py +++ b/numba_dpex/examples/debug/simple_dpex_func.py @@ -6,6 +6,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range @dpex.func(debug=True) @@ -27,6 +28,6 @@ def kernel_sum(a_in_kernel, b_in_kernel, c_in_kernel): device = dpctl.select_default_device() with dpctl.device_context(device): - kernel_sum[global_size, dpex.DEFAULT_LOCAL_SIZE](a, b, c) + kernel_sum[Range(global_size)](a, b, c) print("Done...") diff --git a/numba_dpex/examples/debug/simple_sum.py b/numba_dpex/examples/debug/simple_sum.py index 5a0ea67e0f..e55d1328ff 100644 --- a/numba_dpex/examples/debug/simple_sum.py +++ b/numba_dpex/examples/debug/simple_sum.py @@ -6,6 +6,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range @dpex.kernel(debug=True) @@ -23,6 +24,6 @@ def data_parallel_sum(a, b, c): device = dpctl.select_default_device() with dpctl.device_context(device): - data_parallel_sum[global_size, dpex.DEFAULT_LOCAL_SIZE](a, b, c) + data_parallel_sum[Range(global_size)](a, b, c) print("Done...") diff --git a/numba_dpex/examples/debug/sum.py b/numba_dpex/examples/debug/sum.py index ec44fff306..72cca927b2 100644 --- a/numba_dpex/examples/debug/sum.py +++ b/numba_dpex/examples/debug/sum.py @@ -6,6 +6,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range @dpex.kernel(debug=True) @@ -20,7 +21,7 @@ def driver(a, b, c, global_size): print("before : ", a) print("before : ", b) print("before : ", c) - data_parallel_sum[global_size, dpex.DEFAULT_LOCAL_SIZE](a, b, c) + data_parallel_sum[Range(global_size)](a, b, c) print("after : ", c) diff --git a/numba_dpex/examples/debug/sum_local_vars.py b/numba_dpex/examples/debug/sum_local_vars.py index 72ec60a1fb..3d3b2a9a9c 100644 --- a/numba_dpex/examples/debug/sum_local_vars.py +++ b/numba_dpex/examples/debug/sum_local_vars.py @@ -6,6 +6,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range @dpex.kernel(debug=True) @@ -25,6 +26,6 @@ def data_parallel_sum(a, b, c): device = dpctl.select_default_device() with dpctl.device_context(device): - data_parallel_sum[global_size, dpex.DEFAULT_LOCAL_SIZE](a, b, c) + data_parallel_sum[Range(global_size)](a, b, c) print("Done...") diff --git a/numba_dpex/examples/debug/sum_local_vars_revive.py b/numba_dpex/examples/debug/sum_local_vars_revive.py index f50e22f663..386d54ee77 100644 --- a/numba_dpex/examples/debug/sum_local_vars_revive.py +++ b/numba_dpex/examples/debug/sum_local_vars_revive.py @@ -6,6 +6,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range @dpex.func @@ -31,6 +32,6 @@ def data_parallel_sum(a, b, c): device = dpctl.select_default_device() with dpctl.device_context(device): - data_parallel_sum[global_size, dpex.DEFAULT_LOCAL_SIZE](a, b, c) + data_parallel_sum[Range(global_size)](a, b, c) print("Done...") diff --git a/numba_dpex/examples/kernel/atomic_op.py b/numba_dpex/examples/kernel/atomic_op.py index 2e10f7cc18..653fbc15fe 100644 --- a/numba_dpex/examples/kernel/atomic_op.py +++ b/numba_dpex/examples/kernel/atomic_op.py @@ -5,6 +5,7 @@ import dpnp as np import numba_dpex as ndpex +from numba_dpex.core.kernel_interface.utils import Range @ndpex.kernel @@ -20,7 +21,7 @@ def main(): print("Using device ...") print(a.device) - atomic_reduction[N](a) + atomic_reduction[Range(N)](a) print("Reduction sum =", a[0]) print("Done...") diff --git a/numba_dpex/examples/kernel/black_scholes.py b/numba_dpex/examples/kernel/black_scholes.py index 3f6e9c5bd6..75be6d3b0d 100644 --- a/numba_dpex/examples/kernel/black_scholes.py +++ b/numba_dpex/examples/kernel/black_scholes.py @@ -8,6 +8,7 @@ import dpnp as np import numba_dpex as ndpx +from numba_dpex.core.kernel_interface.utils import Range # Stock price range S0L = 10.0 @@ -94,7 +95,9 @@ def main(): print("Using device ...") print(price.device) - kernel_black_scholes[NOPT](price, strike, t, rate, volatility, call, put) + kernel_black_scholes[Range(NOPT)]( + price, strike, t, rate, volatility, call, put + ) print("Call:", call) print("Put:", put) diff --git a/numba_dpex/examples/kernel/device_func.py b/numba_dpex/examples/kernel/device_func.py index 1c6fe52d39..80089a70fb 100644 --- a/numba_dpex/examples/kernel/device_func.py +++ b/numba_dpex/examples/kernel/device_func.py @@ -6,6 +6,7 @@ import numba_dpex as ndpex from numba_dpex import float32, int32, int64 +from numba_dpex.core.kernel_interface.utils import Range # Array size N = 10 @@ -69,7 +70,7 @@ def test1(): print("A=", a) try: - a_kernel_function[N](a, b) + a_kernel_function[Range(N)](a, b) except Exception as err: print(err) print("B=", b) @@ -87,7 +88,7 @@ def test2(): print("A=", a) try: - a_kernel_function_int32[N](a, b) + a_kernel_function_int32[Range(N)](a, b) except Exception as err: print(err) print("B=", b) @@ -105,7 +106,7 @@ def test3(): print("A=", a) try: - a_kernel_function_int32_float32[N](a, b) + a_kernel_function_int32_float32[Range(N)](a, b) except Exception as err: print(err) print("B=", b) @@ -119,7 +120,7 @@ def test3(): print("A=", a) try: - a_kernel_function_int32_float32[N](a, b) + a_kernel_function_int32_float32[Range(N)](a, b) except Exception as err: print(err) print("B=", b) @@ -134,7 +135,7 @@ def test3(): print("A=", a) try: - a_kernel_function_int32_float32[N](a, b) + a_kernel_function_int32_float32[Range(N)](a, b) except Exception as err: print(err) print("B=", b) diff --git a/numba_dpex/examples/kernel/interpolation.py b/numba_dpex/examples/kernel/interpolation.py index 7568ad60e7..3aa3c91765 100644 --- a/numba_dpex/examples/kernel/interpolation.py +++ b/numba_dpex/examples/kernel/interpolation.py @@ -7,6 +7,7 @@ from numpy.testing import assert_almost_equal import numba_dpex as ndpex +from numba_dpex.core.kernel_interface.utils import NdRange, Range # Interpolation domain XLO = 10.0 @@ -114,9 +115,13 @@ def main(): print("Using device ...") print(xp.device) - global_range = (N_POINTS // N_POINTS_PER_WORK_ITEM,) - local_range = (LOCAL_SIZE,) - kernel_polynomial[global_range, local_range](xp, yp, COEFFICIENTS) + global_range = Range( + N_POINTS // N_POINTS_PER_WORK_ITEM, + ) + local_range = Range( + LOCAL_SIZE, + ) + kernel_polynomial[NdRange(global_range, local_range)](xp, yp, COEFFICIENTS) # Copy results back to the host nyp = np.asnumpy(yp) diff --git a/numba_dpex/examples/kernel/kernel_private_memory.py b/numba_dpex/examples/kernel/kernel_private_memory.py index 089f8b41d4..3219281f7c 100644 --- a/numba_dpex/examples/kernel/kernel_private_memory.py +++ b/numba_dpex/examples/kernel/kernel_private_memory.py @@ -8,6 +8,7 @@ from numba import float32 import numba_dpex +from numba_dpex.core.kernel_interface.utils import NdRange, Range def private_memory(): @@ -39,9 +40,9 @@ def private_memory_kernel(A): print("Using device ...") device.print_device_info() - global_range = (N,) - local_range = (N,) - private_memory_kernel[global_range, local_range](arr) + global_range = Range(N) + local_range = Range(N) + private_memory_kernel[NdRange(global_range, local_range)](arr) arr_out = dpt.asnumpy(arr) np.testing.assert_allclose(orig * 2, arr_out) diff --git a/numba_dpex/examples/kernel/kernel_specialization.py b/numba_dpex/examples/kernel/kernel_specialization.py index a3cd7fa759..e1aff12c23 100644 --- a/numba_dpex/examples/kernel/kernel_specialization.py +++ b/numba_dpex/examples/kernel/kernel_specialization.py @@ -11,6 +11,7 @@ InvalidKernelSpecializationError, MissingSpecializationError, ) +from numba_dpex.core.kernel_interface.utils import Range # Similar to Numba, numba-dpex supports eager compilation of functions. The # following examples demonstrate the feature for numba_dpex.kernel and presents @@ -38,7 +39,7 @@ def data_parallel_sum(a, b, c): b = dpt.ones(1024, dtype=dpt.int64) c = dpt.zeros(1024, dtype=dpt.int64) -data_parallel_sum[1024](a, b, c) +data_parallel_sum[Range(1024)](a, b, c) npc = dpt.asnumpy(c) npc_expected = np.full(1024, 2, dtype=np.int64) @@ -65,7 +66,7 @@ def data_parallel_sum2(a, b, c): b = dpt.ones(1024, dtype=dpt.int64) c = dpt.zeros(1024, dtype=dpt.int64) -data_parallel_sum2[1024](a, b, c) +data_parallel_sum2[Range(1024)](a, b, c) npc = dpt.asnumpy(c) npc_expected = np.full(1024, 2, dtype=np.int64) @@ -76,7 +77,7 @@ def data_parallel_sum2(a, b, c): b = dpt.ones(1024, dtype=dpt.float32) c = dpt.zeros(1024, dtype=dpt.float32) -data_parallel_sum2[1024](a, b, c) +data_parallel_sum2[Range(1024)](a, b, c) npc = dpt.asnumpy(c) npc_expected = np.full(1024, 2, dtype=np.float32) @@ -94,7 +95,7 @@ def data_parallel_sum2(a, b, c): c = dpt.zeros(1024, dtype=dpt.int32) try: - data_parallel_sum[1024](a, b, c) + data_parallel_sum[Range(1024)](a, b, c) except MissingSpecializationError as mse: print(mse) @@ -128,3 +129,5 @@ def data_parallel_sum2(a, b, c): "strings." ) print(e) + +print("Done...") diff --git a/numba_dpex/examples/kernel/matmul.py b/numba_dpex/examples/kernel/matmul.py index a40ccc207b..5fd8e44832 100644 --- a/numba_dpex/examples/kernel/matmul.py +++ b/numba_dpex/examples/kernel/matmul.py @@ -9,6 +9,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import NdRange, Range @dpex.kernel @@ -30,13 +31,13 @@ def gemm(a, b, c): Y = 16 global_size = X, X -griddim = X, X -blockdim = Y, Y +griddim = Range(X, X) +blockdim = Range(Y, Y) def driver(a, b, c): # Invoke the kernel - gemm[griddim, blockdim](a, b, c) + gemm[NdRange(griddim, blockdim)](a, b, c) def main(): diff --git a/numba_dpex/examples/kernel/pairwise_distance.py b/numba_dpex/examples/kernel/pairwise_distance.py index 30d940a871..da4822c64f 100644 --- a/numba_dpex/examples/kernel/pairwise_distance.py +++ b/numba_dpex/examples/kernel/pairwise_distance.py @@ -12,6 +12,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import NdRange, Range parser = argparse.ArgumentParser( description="Program to compute pairwise distance" @@ -25,9 +26,9 @@ args = parser.parse_args() # Global work size is equal to the number of points -global_size = (args.n,) +global_size = Range(args.n) # Local Work size is optional -local_size = (args.l,) +local_size = Range(args.l) X = np.random.random((args.n, args.d)).astype(np.single) D = np.empty((args.n, args.n), dtype=np.single) @@ -65,7 +66,7 @@ def driver(): for repeat in range(args.r): start = time() - pairwise_distance[global_size, local_size]( + pairwise_distance[NdRange(global_size, local_size)]( x_ndarray, d_ndarray, X.shape[0], X.shape[1] ) end = time() diff --git a/numba_dpex/examples/kernel/scan.py b/numba_dpex/examples/kernel/scan.py index 6ee4056fbb..13374bbf4b 100644 --- a/numba_dpex/examples/kernel/scan.py +++ b/numba_dpex/examples/kernel/scan.py @@ -7,6 +7,7 @@ import dpnp as np import numba_dpex as ndpx +from numba_dpex.core.kernel_interface.utils import Range # 1D array size N = 64 @@ -56,7 +57,7 @@ def main(): print("Using device ...") print(arr.device) - kernel_hillis_steele_scan[N](arr) + kernel_hillis_steele_scan[Range(N)](arr) # the output should be [0, 1, 3, 6, ...] arr_np = np.asnumpy(arr) diff --git a/numba_dpex/examples/kernel/select_device_for_kernel.py b/numba_dpex/examples/kernel/select_device_for_kernel.py index 7c08d7e9eb..fbe1f27bd1 100644 --- a/numba_dpex/examples/kernel/select_device_for_kernel.py +++ b/numba_dpex/examples/kernel/select_device_for_kernel.py @@ -9,6 +9,7 @@ import numpy as np import numba_dpex +from numba_dpex.core.kernel_interface.utils import NdRange, Range """ We support passing arrays of two types to a @numba_dpex.kernel decorated @@ -86,7 +87,7 @@ def select_device_ndarray(N): default_device = dpctl.select_default_device() with numba_dpex.offload_to_sycl_device(default_device.filter_string): - sum_kernel[(N,), (1,)](a, b, got) + sum_kernel[NdRange(Range(N), Range(1))](a, b, got) expected = a + b @@ -110,7 +111,7 @@ def select_device_SUAI(N): # Users don't need to specify where the computation will # take place. It will be inferred from data. - sum_kernel[(N,), (1,)](da, db, dc) + sum_kernel[NdRange(Range(N), Range(1))](da, db, dc) dc.usm_data.copy_to_host(got.reshape((-1)).view("|u1")) diff --git a/numba_dpex/examples/kernel/sum_reduction_ocl.py b/numba_dpex/examples/kernel/sum_reduction_ocl.py index 9ab19bdebd..03ddecdd0f 100644 --- a/numba_dpex/examples/kernel/sum_reduction_ocl.py +++ b/numba_dpex/examples/kernel/sum_reduction_ocl.py @@ -7,6 +7,7 @@ from numba import int32 import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import NdRange, Range @dpex.kernel @@ -49,9 +50,9 @@ def sum_reduce(A): partial_sums = dpt.zeros(nb_work_groups, dtype=A.dtype, device=A.device) - gs = (global_size,) - ls = (work_group_size,) - sum_reduction_kernel[gs, ls](A, partial_sums) + gs = Range(global_size) + ls = Range(work_group_size) + sum_reduction_kernel[NdRange(gs, ls)](A, partial_sums) final_sum = 0 # calculate the final sum in HOST diff --git a/numba_dpex/examples/kernel/sum_reduction_recursive_ocl.py b/numba_dpex/examples/kernel/sum_reduction_recursive_ocl.py index 40183c5931..b90a985df0 100644 --- a/numba_dpex/examples/kernel/sum_reduction_recursive_ocl.py +++ b/numba_dpex/examples/kernel/sum_reduction_recursive_ocl.py @@ -13,6 +13,7 @@ from numba import int32 import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import NdRange, Range @dpex.kernel @@ -58,13 +59,15 @@ def sum_recursive_reduction(size, group_size, Dinp, Dpartial_sums): nb_work_groups += 1 passed_size = nb_work_groups * group_size - gr = (passed_size,) - lr = (group_size,) + gr = Range(passed_size) + lr = Range(group_size) - sum_reduction_kernel[gr, lr](Dinp, size, Dpartial_sums) + sum_reduction_kernel[NdRange(gr, lr)](Dinp, size, Dpartial_sums) if nb_work_groups <= group_size: - sum_reduction_kernel[lr, lr](Dpartial_sums, nb_work_groups, Dinp) + sum_reduction_kernel[NdRange(lr, lr)]( + Dpartial_sums, nb_work_groups, Dinp + ) result = int(Dinp[0]) else: result = sum_recursive_reduction( diff --git a/numba_dpex/examples/kernel/vector_sum.py b/numba_dpex/examples/kernel/vector_sum.py index cb1b9fa2bb..40ccc268ba 100644 --- a/numba_dpex/examples/kernel/vector_sum.py +++ b/numba_dpex/examples/kernel/vector_sum.py @@ -6,6 +6,7 @@ import numpy.testing as testing import numba_dpex as ndpx +from numba_dpex.core.kernel_interface.utils import Range # Data parallel kernel implementing vector sum @@ -18,7 +19,7 @@ def kernel_vector_sum(a, b, c): # Utility function for printing and testing def driver(a, b, c, global_size): - kernel_vector_sum[global_size](a, b, c) + kernel_vector_sum[Range(global_size)](a, b, c) a_np = dpnp.asnumpy(a) # Copy dpnp array a to NumPy array a_np b_np = dpnp.asnumpy(b) # Copy dpnp array b to NumPy array b_np diff --git a/numba_dpex/examples/kernel/vector_sum2D.py b/numba_dpex/examples/kernel/vector_sum2D.py index 089721b7c1..5547698df8 100644 --- a/numba_dpex/examples/kernel/vector_sum2D.py +++ b/numba_dpex/examples/kernel/vector_sum2D.py @@ -9,6 +9,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import NdRange, Range @dpex.kernel @@ -29,7 +30,7 @@ def main(): # Array dimensions X = 8 Y = 8 - global_size = X, Y + global_size = Range(X, Y) a = np.arange(X * Y, dtype=np.float32).reshape(X, Y) b = np.arange(X * Y, dtype=np.float32).reshape(X, Y) @@ -48,8 +49,8 @@ def main(): print("Using device ...") device.print_device_info() + print("Running kernel ...") driver(a_dpt, b_dpt, c_dpt, global_size) - c_out = dpt.asnumpy(c_dpt) assert np.allclose(c, c_out) diff --git a/numba_dpex/examples/sum_reduction.py b/numba_dpex/examples/sum_reduction.py index cecafa5603..e7a47c65e7 100644 --- a/numba_dpex/examples/sum_reduction.py +++ b/numba_dpex/examples/sum_reduction.py @@ -8,6 +8,7 @@ import numpy as np import numba_dpex as dpex +from numba_dpex.core.kernel_interface.utils import Range @dpex.kernel @@ -34,7 +35,7 @@ def sum_reduce(A): with dpctl.device_context(device): while total > 1: global_size = total // 2 - sum_reduction_kernel[global_size](A, R, global_size) + sum_reduction_kernel[Range(global_size)](A, R, global_size) total = total // 2 return R[0] diff --git a/numba_dpex/tests/kernel_tests/test_kernel_launch_params.py b/numba_dpex/tests/kernel_tests/test_kernel_launch_params.py index fa7658623d..4e6d697329 100644 --- a/numba_dpex/tests/kernel_tests/test_kernel_launch_params.py +++ b/numba_dpex/tests/kernel_tests/test_kernel_launch_params.py @@ -2,12 +2,15 @@ # # SPDX-License-Identifier: Apache-2.0 +import dpctl +import dpctl.tensor as dpt import pytest import numba_dpex as dpex from numba_dpex.core.exceptions import ( IllegalRangeValueError, InvalidKernelLaunchArgsError, + UnknownGlobalRangeError, ) @@ -37,25 +40,19 @@ def test_1D_global_range_as_list(): assert k._local_range is None -def test_1D_global_range_and_1D_local_range(): - k = vecadd[10, 10] - assert k._global_range == [10] - assert k._local_range == [10] - - -def test_1D_global_range_and_1D_local_range2(): +def test_1D_global_range_and_1D_local_range1(): k = vecadd[[10, 10]] assert k._global_range == [10] assert k._local_range == [10] -def test_1D_global_range_and_1D_local_range3(): +def test_1D_global_range_and_1D_local_range2(): k = vecadd[(10,), (10,)] assert k._global_range == [10] assert k._local_range == [10] -def test_2D_global_range_and_2D_local_range(): +def test_2D_global_range_and_2D_local_range1(): k = vecadd[(10, 10), (10, 10)] assert k._global_range == [10, 10] assert k._local_range == [10, 10] @@ -79,7 +76,7 @@ def test_2D_global_range_and_2D_local_range4(): assert k._local_range == [10, 10] -def test_deprecation_warning_for_empty_local_range(): +def test_deprecation_warning_for_empty_local_range1(): with pytest.deprecated_call(): k = vecadd[[10, 10], []] assert k._global_range == [10, 10] @@ -93,12 +90,45 @@ def test_deprecation_warning_for_empty_local_range2(): assert k._local_range is None -def test_illegal_kernel_launch_arg(): +def test_ambiguous_kernel_launch_params(): + with pytest.deprecated_call(): + k = vecadd[10, 10] + assert k._global_range == [10] + assert k._local_range == [10] + + with pytest.deprecated_call(): + k = vecadd[(10, 10)] + assert k._global_range == [10] + assert k._local_range == [10] + + with pytest.deprecated_call(): + k = vecadd[((10), (10))] + assert k._global_range == [10] + assert k._local_range == [10] + + +def test_unknown_global_range_error(): + device = dpctl.select_default_device() + a = dpt.ones(10, dtype=dpt.int16, device=device) + b = dpt.ones(10, dtype=dpt.int16, device=device) + c = dpt.zeros(10, dtype=dpt.int16, device=device) + try: + vecadd(a, b, c) + except UnknownGlobalRangeError as e: + assert "No global range" in e.message + + +def test_illegal_kernel_launch_arg1(): + with pytest.raises(InvalidKernelLaunchArgsError): + vecadd[()] + + +def test_illegal_kernel_launch_arg2(): with pytest.raises(InvalidKernelLaunchArgsError): vecadd[10, 10, []] -def test_illegal_range_error(): +def test_illegal_range_error1(): with pytest.raises(IllegalRangeValueError): vecadd[[], []] @@ -111,3 +141,7 @@ def test_illegal_range_error2(): def test_illegal_range_error3(): with pytest.raises(IllegalRangeValueError): vecadd[(), 10] + + +if __name__ == "__main__": + test_unknown_global_range_error() diff --git a/numba_dpex/tests/kernel_tests/test_ndrange_exceptions.py b/numba_dpex/tests/kernel_tests/test_ndrange_exceptions.py index 9211a5366b..aa4e9e33b6 100644 --- a/numba_dpex/tests/kernel_tests/test_ndrange_exceptions.py +++ b/numba_dpex/tests/kernel_tests/test_ndrange_exceptions.py @@ -5,10 +5,7 @@ import pytest import numba_dpex as ndpx -from numba_dpex.core.exceptions import ( - UnmatchedNumberOfRangeDimsError, - UnsupportedGroupWorkItemSizeError, -) +from numba_dpex.core.kernel_interface.utils import NdRange # Data parallel kernel implementing vector sum @@ -19,13 +16,13 @@ def kernel_vector_sum(a, b, c): @pytest.mark.parametrize( - "error, ndrange", + "error, ranges", [ - (UnmatchedNumberOfRangeDimsError, ((2, 2), (1, 1, 1))), - (UnsupportedGroupWorkItemSizeError, ((3, 3, 3), (2, 2, 2))), + (TypeError, ((2, 2), ("a", 1, 1))), + (TypeError, ((3, 3, 3, 3), (2, 2, 2))), ], ) -def test_ndrange_config_error(error, ndrange): +def test_ndrange_config_error(error, ranges): """Test if a exception is raised when calling a ndrange kernel with unspported arguments. """ @@ -35,4 +32,5 @@ def test_ndrange_config_error(error, ndrange): c = dpt.zeros(1024, dtype=dpt.int64) with pytest.raises(error): - kernel_vector_sum[ndrange](a, b, c) + range = NdRange(ranges[0], ranges[1]) + kernel_vector_sum[range](a, b, c)