33
33
)
34
34
from numba_dpex .core .kernel_interface .arg_pack_unpacker import Packer
35
35
from numba_dpex .core .kernel_interface .spirv_kernel import SpirvKernel
36
- from numba_dpex .core .kernel_interface .utils import Ranges
36
+ from numba_dpex .core .kernel_interface .utils import NdRange , Range
37
37
from numba_dpex .core .types import USMNdArray
38
38
39
- simplefilter ("always" , DeprecationWarning )
40
-
41
39
42
40
def get_ordered_arg_access_types (pyfunc , access_types ):
43
41
"""Deprecated and to be removed in next release."""
@@ -445,56 +443,6 @@ def _determine_kernel_launch_queue(self, args, argtypes):
445
443
else :
446
444
raise ExecutionQueueInferenceError (self .kernel_name )
447
445
448
- def _raise_invalid_kernel_enqueue_args (self ):
449
- error_message = (
450
- "Incorrect number of arguments for enqueuing numba_dpex.kernel. "
451
- "Usage: device_env, global size, local size. "
452
- "The local size argument is optional."
453
- )
454
- raise InvalidKernelLaunchArgsError (error_message )
455
-
456
- def _ensure_valid_work_item_grid (self , val ):
457
- if not isinstance (val , (tuple , list , int )):
458
- error_message = (
459
- "Cannot create work item dimension from provided argument"
460
- )
461
- raise ValueError (error_message )
462
-
463
- if isinstance (val , int ):
464
- val = [val ]
465
-
466
- # TODO: we need some way to check the max dimensions
467
- """
468
- if len(val) > device_env.get_max_work_item_dims():
469
- error_message = ("Unsupported number of work item dimensions ")
470
- raise ValueError(error_message)
471
- """
472
-
473
- return list (
474
- val [::- 1 ]
475
- ) # reversing due to sycl and opencl interop kernel range mismatch semantic
476
-
477
- def _ensure_valid_work_group_size (self , val , work_item_grid ):
478
- if not isinstance (val , (tuple , list , int )):
479
- error_message = (
480
- "Cannot create work item dimension from provided argument"
481
- )
482
- raise ValueError (error_message )
483
-
484
- if isinstance (val , int ):
485
- val = [val ]
486
-
487
- if len (val ) != len (work_item_grid ):
488
- error_message = (
489
- "Unsupported number of work item dimensions, "
490
- + "dimensions of global and local work items has to be the same "
491
- )
492
- raise IllegalRangeValueError (error_message )
493
-
494
- return list (
495
- val [::- 1 ]
496
- ) # reversing due to sycl and opencl interop kernel range mismatch semantic
497
-
498
446
def __getitem__ (self , args ):
499
447
"""Mimic's ``numba.cuda`` square-bracket notation for configuring the
500
448
global_range and local_range settings when launching a kernel on a
@@ -522,8 +470,11 @@ def __getitem__(self, args):
522
470
global_range and local_range attributes initialized.
523
471
524
472
"""
525
-
526
- if isinstance (args , Ranges ):
473
+ if isinstance (args , Range ):
474
+ # we need inversions, see github issue #889
475
+ self ._global_range = list (args )[::- 1 ]
476
+ elif isinstance (args , NdRange ):
477
+ # we need inversions, see github issue #889
527
478
self ._global_range = list (args .global_range )[::- 1 ]
528
479
self ._local_range = list (args .local_range )[::- 1 ]
529
480
else :
@@ -534,44 +485,73 @@ def __getitem__(self, args):
534
485
and isinstance (args [1 ], int )
535
486
):
536
487
warn (
537
- "Ambiguous kernel launch paramters. "
538
- + "If your data have dimensions > 1, "
539
- + "include a default/empty local_range. "
540
- + "i.e. <function>[(M,N), numba_dpex.DEFAULT_LOCAL_RANGE](<params>), "
488
+ "Ambiguous kernel launch paramters. If your data have "
489
+ + "dimensions > 1, include a default/empty local_range:\n "
490
+ + " <function>[(X,Y), numba_dpex.DEFAULT_LOCAL_RANGE](<params>)\n "
541
491
+ "otherwise your code might produce erroneous results." ,
542
492
DeprecationWarning ,
493
+ stacklevel = 2 ,
543
494
)
544
495
self ._global_range = [args [0 ]]
545
496
self ._local_range = [args [1 ]]
546
497
return self
547
498
548
- if not isinstance (args , Iterable ):
549
- args = [args ]
499
+ warn (
500
+ "The current syntax for specification of kernel lauch "
501
+ + "parameters is deprecated. Users should set the kernel "
502
+ + "parameters through Range/NdRange classes.\n "
503
+ + "Example:\n "
504
+ + " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n \n "
505
+ + " # for global range only\n "
506
+ + " <function>[Range(X,Y)](<parameters>)\n "
507
+ + " # or,\n "
508
+ + " # for both global and local ranges\n "
509
+ + " <function>[NdRange((X,Y), (P,Q))](<parameters>)" ,
510
+ DeprecationWarning ,
511
+ stacklevel = 2 ,
512
+ )
550
513
551
- ls = None
514
+ args = [ args ] if not isinstance ( args , Iterable ) else args
552
515
nargs = len (args )
516
+
553
517
# Check if the kernel enquing arguments are sane
554
518
if nargs < 1 or nargs > 2 :
555
- self ._raise_invalid_kernel_enqueue_args ( )
519
+ raise InvalidKernelLaunchArgsError ( kernel_name = self .kernel_name )
556
520
557
- gs = self ._ensure_valid_work_item_grid (args [0 ])
521
+ g_range = (
522
+ [args [0 ]] if not isinstance (args [0 ], Iterable ) else args [0 ]
523
+ )
558
524
# If the optional local size argument is provided
525
+ l_range = None
559
526
if nargs == 2 :
560
527
if args [1 ] != []:
561
- ls = self ._ensure_valid_work_group_size (args [1 ], gs )
528
+ l_range = (
529
+ [args [1 ]]
530
+ if not isinstance (args [1 ], Iterable )
531
+ else args [1 ]
532
+ )
562
533
else :
563
534
warn (
564
- "Empty local_range calls will be deprecated in the future." ,
535
+ "Empty local_range calls are deprecated. Please use Range/NdRange "
536
+ + "to specify the kernel launch parameters:\n "
537
+ + "Example:\n "
538
+ + " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n \n "
539
+ + " # for global range only\n "
540
+ + " <function>[Range(X,Y)](<parameters>)\n "
541
+ + " # or,\n "
542
+ + " # for both global and local ranges\n "
543
+ + " <function>[NdRange((X,Y), (P,Q))](<parameters>)" ,
565
544
DeprecationWarning ,
545
+ stacklevel = 2 ,
566
546
)
567
547
568
- self ._global_range = list (gs )[::- 1 ]
569
- self ._local_range = list (ls )[::- 1 ] if ls else None
548
+ if len (g_range ) < 1 :
549
+ raise IllegalRangeValueError (kernel_name = self .kernel_name )
550
+
551
+ # we need inversions, see github issue #889
552
+ self ._global_range = list (g_range )[::- 1 ]
553
+ self ._local_range = list (l_range )[::- 1 ] if l_range else None
570
554
571
- if self ._global_range == [] and self ._local_range is None :
572
- raise IllegalRangeValueError (
573
- "Illegal range values for kernel launch parameters."
574
- )
575
555
return self
576
556
577
557
def _check_ranges (self , device ):
0 commit comments