5
5
from llvmlite import ir as llvmir
6
6
from numba .core import cgutils , types
7
7
8
- from numba_dpex import utils
8
+ from numba_dpex import config , utils
9
9
from numba_dpex .core .runtime .context import DpexRTContext
10
10
from numba_dpex .core .types import DpnpNdArray
11
11
from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
@@ -361,6 +361,60 @@ def _create_sycl_range(self, idx_range):
361
361
362
362
return self .builder .bitcast (range_list , intp_ptr_t )
363
363
364
+ def submit_kernel (
365
+ self ,
366
+ kernel_ref : llvmir .CallInstr ,
367
+ queue_ref : llvmir .PointerType ,
368
+ kernel_args : list ,
369
+ ty_kernel_args : list ,
370
+ global_range_extents : list ,
371
+ local_range_extents : list ,
372
+ ):
373
+ if config .DEBUG_KERNEL_LAUNCHER :
374
+ cgutils .printf (
375
+ self .builder ,
376
+ "DPEX-DEBUG: Populating kernel args and arg type arrays.\n " ,
377
+ )
378
+
379
+ num_flattened_kernel_args = self .get_num_flattened_kernel_args (
380
+ kernel_argtys = ty_kernel_args ,
381
+ )
382
+
383
+ # Create LLVM values for the kernel args list and kernel arg types list
384
+ args_list = self .allocate_kernel_arg_array (num_flattened_kernel_args )
385
+
386
+ args_ty_list = self .allocate_kernel_arg_ty_array (
387
+ num_flattened_kernel_args
388
+ )
389
+
390
+ kernel_args_ptrs = []
391
+ for arg in kernel_args :
392
+ ptr = self .builder .alloca (arg .type )
393
+ self .builder .store (arg , ptr )
394
+ kernel_args_ptrs .append (ptr )
395
+
396
+ # Populate the args_list and the args_ty_list LLVM arrays
397
+ self .populate_kernel_args_and_args_ty_arrays (
398
+ callargs_ptrs = kernel_args_ptrs ,
399
+ kernel_argtys = ty_kernel_args ,
400
+ args_list = args_list ,
401
+ args_ty_list = args_ty_list ,
402
+ )
403
+
404
+ if config .DEBUG_KERNEL_LAUNCHER :
405
+ cgutils .printf (self ._builder , "DPEX-DEBUG: Submit kernel.\n " )
406
+
407
+ return self .submit_sycl_kernel (
408
+ sycl_kernel_ref = kernel_ref ,
409
+ sycl_queue_ref = queue_ref ,
410
+ total_kernel_args = num_flattened_kernel_args ,
411
+ arg_list = args_list ,
412
+ arg_ty_list = args_ty_list ,
413
+ global_range = global_range_extents ,
414
+ local_range = local_range_extents ,
415
+ wait_before_return = False ,
416
+ )
417
+
364
418
def submit_sycl_kernel (
365
419
self ,
366
420
sycl_kernel_ref ,
@@ -373,7 +427,7 @@ def submit_sycl_kernel(
373
427
wait_before_return = True ,
374
428
) -> llvmir .PointerType (llvmir .IntType (8 )):
375
429
"""
376
- Submits the kernel to the specified queue, waits.
430
+ Submits the kernel to the specified queue, waits by default .
377
431
"""
378
432
eref = None
379
433
gr = self ._create_sycl_range (global_range )
@@ -411,19 +465,34 @@ def submit_sycl_kernel(
411
465
else :
412
466
return eref
413
467
468
+ def get_num_flattened_kernel_args (
469
+ self ,
470
+ kernel_argtys : tuple [types .Type , ...],
471
+ ):
472
+ num_flattened_kernel_args = 0
473
+ for arg_type in kernel_argtys :
474
+ if isinstance (arg_type , DpnpNdArray ):
475
+ datamodel = self .context .data_model_manager .lookup (arg_type )
476
+ num_flattened_kernel_args += datamodel .flattened_field_count
477
+ elif arg_type in [types .complex64 , types .complex128 ]:
478
+ num_flattened_kernel_args += 2
479
+ else :
480
+ num_flattened_kernel_args += 1
481
+
482
+ return num_flattened_kernel_args
483
+
414
484
def populate_kernel_args_and_args_ty_arrays (
415
485
self ,
416
486
kernel_argtys ,
417
487
callargs_ptrs ,
418
488
args_list ,
419
489
args_ty_list ,
420
- datamodel_mgr ,
421
490
):
422
491
kernel_arg_num = 0
423
492
for arg_num , argtype in enumerate (kernel_argtys ):
424
493
llvm_val = callargs_ptrs [arg_num ]
425
494
if isinstance (argtype , DpnpNdArray ):
426
- datamodel = datamodel_mgr .lookup (argtype )
495
+ datamodel = self . context . data_model_manager .lookup (argtype )
427
496
self .build_array_arg (
428
497
array_val = llvm_val ,
429
498
array_data_model = datamodel ,
0 commit comments