Skip to content

Commit f536bd1

Browse files
committed
A better implementation for the Range/NdRange class, with examples
Better deprecation warnings added Following exact sycl::range/nd_range specification for kernel lauch parameters Default cache size is set to 128, like numba
1 parent e66c0fc commit f536bd1

File tree

7 files changed

+237
-172
lines changed

7 files changed

+237
-172
lines changed

numba_dpex/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __getattr__(name):
9999
ENABLE_CACHE = _readenv("NUMBA_DPEX_ENABLE_CACHE", int, 1)
100100
# Capacity of the cache, execute it like:
101101
# NUMBA_DPEX_CACHE_SIZE=20 python <code>
102-
CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 10)
102+
CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 128)
103103

104104
TESTING_SKIP_NO_DPNP = _readenv("NUMBA_DPEX_TESTING_SKIP_NO_DPNP", int, 0)
105105
TESTING_SKIP_NO_DEBUGGING = _readenv(

numba_dpex/core/kernel_interface/dispatcher.py

Lines changed: 52 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,9 @@
3333
)
3434
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
3535
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
36-
from numba_dpex.core.kernel_interface.utils import Ranges
36+
from numba_dpex.core.kernel_interface.utils import NdRange, Range
3737
from numba_dpex.core.types import USMNdArray
3838

39-
simplefilter("always", DeprecationWarning)
40-
4139

4240
def get_ordered_arg_access_types(pyfunc, access_types):
4341
"""Deprecated and to be removed in next release."""
@@ -445,56 +443,6 @@ def _determine_kernel_launch_queue(self, args, argtypes):
445443
else:
446444
raise ExecutionQueueInferenceError(self.kernel_name)
447445

448-
def _raise_invalid_kernel_enqueue_args(self):
449-
error_message = (
450-
"Incorrect number of arguments for enqueuing numba_dpex.kernel. "
451-
"Usage: device_env, global size, local size. "
452-
"The local size argument is optional."
453-
)
454-
raise InvalidKernelLaunchArgsError(error_message)
455-
456-
def _ensure_valid_work_item_grid(self, val):
457-
if not isinstance(val, (tuple, list, int)):
458-
error_message = (
459-
"Cannot create work item dimension from provided argument"
460-
)
461-
raise ValueError(error_message)
462-
463-
if isinstance(val, int):
464-
val = [val]
465-
466-
# TODO: we need some way to check the max dimensions
467-
"""
468-
if len(val) > device_env.get_max_work_item_dims():
469-
error_message = ("Unsupported number of work item dimensions ")
470-
raise ValueError(error_message)
471-
"""
472-
473-
return list(
474-
val[::-1]
475-
) # reversing due to sycl and opencl interop kernel range mismatch semantic
476-
477-
def _ensure_valid_work_group_size(self, val, work_item_grid):
478-
if not isinstance(val, (tuple, list, int)):
479-
error_message = (
480-
"Cannot create work item dimension from provided argument"
481-
)
482-
raise ValueError(error_message)
483-
484-
if isinstance(val, int):
485-
val = [val]
486-
487-
if len(val) != len(work_item_grid):
488-
error_message = (
489-
"Unsupported number of work item dimensions, "
490-
+ "dimensions of global and local work items has to be the same "
491-
)
492-
raise IllegalRangeValueError(error_message)
493-
494-
return list(
495-
val[::-1]
496-
) # reversing due to sycl and opencl interop kernel range mismatch semantic
497-
498446
def __getitem__(self, args):
499447
"""Mimic's ``numba.cuda`` square-bracket notation for configuring the
500448
global_range and local_range settings when launching a kernel on a
@@ -522,8 +470,11 @@ def __getitem__(self, args):
522470
global_range and local_range attributes initialized.
523471
524472
"""
525-
526-
if isinstance(args, Ranges):
473+
if isinstance(args, Range):
474+
# we need inversions, see github issue #889
475+
self._global_range = list(args)[::-1]
476+
elif isinstance(args, NdRange):
477+
# we need inversions, see github issue #889
527478
self._global_range = list(args.global_range)[::-1]
528479
self._local_range = list(args.local_range)[::-1]
529480
else:
@@ -534,44 +485,73 @@ def __getitem__(self, args):
534485
and isinstance(args[1], int)
535486
):
536487
warn(
537-
"Ambiguous kernel launch paramters. "
538-
+ "If your data have dimensions > 1, "
539-
+ "include a default/empty local_range. "
540-
+ "i.e. <function>[(M,N), numba_dpex.DEFAULT_LOCAL_RANGE](<params>), "
488+
"Ambiguous kernel launch paramters. If your data have "
489+
+ "dimensions > 1, include a default/empty local_range:\n"
490+
+ " <function>[(X,Y), numba_dpex.DEFAULT_LOCAL_RANGE](<params>)\n"
541491
+ "otherwise your code might produce erroneous results.",
542492
DeprecationWarning,
493+
stacklevel=2,
543494
)
544495
self._global_range = [args[0]]
545496
self._local_range = [args[1]]
546497
return self
547498

548-
if not isinstance(args, Iterable):
549-
args = [args]
499+
warn(
500+
"The current syntax for specification of kernel lauch "
501+
+ "parameters is deprecated. Users should set the kernel "
502+
+ "parameters through Range/NdRange classes.\n"
503+
+ "Example:\n"
504+
+ " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n"
505+
+ " # for global range only\n"
506+
+ " <function>[Range(X,Y)](<parameters>)\n"
507+
+ " # or,\n"
508+
+ " # for both global and local ranges\n"
509+
+ " <function>[NdRange((X,Y), (P,Q))](<parameters>)",
510+
DeprecationWarning,
511+
stacklevel=2,
512+
)
550513

551-
ls = None
514+
args = [args] if not isinstance(args, Iterable) else args
552515
nargs = len(args)
516+
553517
# Check if the kernel enquing arguments are sane
554518
if nargs < 1 or nargs > 2:
555-
self._raise_invalid_kernel_enqueue_args()
519+
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)
556520

557-
gs = self._ensure_valid_work_item_grid(args[0])
521+
g_range = (
522+
[args[0]] if not isinstance(args[0], Iterable) else args[0]
523+
)
558524
# If the optional local size argument is provided
525+
l_range = None
559526
if nargs == 2:
560527
if args[1] != []:
561-
ls = self._ensure_valid_work_group_size(args[1], gs)
528+
l_range = (
529+
[args[1]]
530+
if not isinstance(args[1], Iterable)
531+
else args[1]
532+
)
562533
else:
563534
warn(
564-
"Empty local_range calls will be deprecated in the future.",
535+
"Empty local_range calls are deprecated. Please use Range/NdRange "
536+
+ "to specify the kernel launch parameters:\n"
537+
+ "Example:\n"
538+
+ " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n"
539+
+ " # for global range only\n"
540+
+ " <function>[Range(X,Y)](<parameters>)\n"
541+
+ " # or,\n"
542+
+ " # for both global and local ranges\n"
543+
+ " <function>[NdRange((X,Y), (P,Q))](<parameters>)",
565544
DeprecationWarning,
545+
stacklevel=2,
566546
)
567547

568-
self._global_range = list(gs)[::-1]
569-
self._local_range = list(ls)[::-1] if ls else None
548+
if len(g_range) < 1:
549+
raise IllegalRangeValueError(kernel_name=self.kernel_name)
550+
551+
# we need inversions, see github issue #889
552+
self._global_range = list(g_range)[::-1]
553+
self._local_range = list(l_range)[::-1] if l_range else None
570554

571-
if self._global_range == [] and self._local_range is None:
572-
raise IllegalRangeValueError(
573-
"Illegal range values for kernel launch parameters."
574-
)
575555
return self
576556

577557
def _check_ranges(self, device):

0 commit comments

Comments
 (0)