Skip to content

Commit a62ff1a

Browse files
authored
Merge pull request #1231 from IntelPython/feature/clean_experimental_launcher
Clean up experimental launcher
2 parents 08edbb1 + 2b52a19 commit a62ff1a

File tree

3 files changed

+243
-310
lines changed

3 files changed

+243
-310
lines changed

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
131131
callargs_ptrs=callargs_ptrs,
132132
args_list=args_list,
133133
args_ty_list=args_ty_list,
134-
datamodel_mgr=dpex_dmm,
135134
)
136135

137136
return _KernelArgs(

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from llvmlite import ir as llvmir
66
from numba.core import cgutils, types
77

8-
from numba_dpex import utils
8+
from numba_dpex import config, utils
99
from numba_dpex.core.runtime.context import DpexRTContext
1010
from numba_dpex.core.types import DpnpNdArray
1111
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
@@ -361,6 +361,60 @@ def _create_sycl_range(self, idx_range):
361361

362362
return self.builder.bitcast(range_list, intp_ptr_t)
363363

364+
def submit_kernel(
365+
self,
366+
kernel_ref: llvmir.CallInstr,
367+
queue_ref: llvmir.PointerType,
368+
kernel_args: list,
369+
ty_kernel_args: list,
370+
global_range_extents: list,
371+
local_range_extents: list,
372+
):
373+
if config.DEBUG_KERNEL_LAUNCHER:
374+
cgutils.printf(
375+
self.builder,
376+
"DPEX-DEBUG: Populating kernel args and arg type arrays.\n",
377+
)
378+
379+
num_flattened_kernel_args = self.get_num_flattened_kernel_args(
380+
kernel_argtys=ty_kernel_args,
381+
)
382+
383+
# Create LLVM values for the kernel args list and kernel arg types list
384+
args_list = self.allocate_kernel_arg_array(num_flattened_kernel_args)
385+
386+
args_ty_list = self.allocate_kernel_arg_ty_array(
387+
num_flattened_kernel_args
388+
)
389+
390+
kernel_args_ptrs = []
391+
for arg in kernel_args:
392+
ptr = self.builder.alloca(arg.type)
393+
self.builder.store(arg, ptr)
394+
kernel_args_ptrs.append(ptr)
395+
396+
# Populate the args_list and the args_ty_list LLVM arrays
397+
self.populate_kernel_args_and_args_ty_arrays(
398+
callargs_ptrs=kernel_args_ptrs,
399+
kernel_argtys=ty_kernel_args,
400+
args_list=args_list,
401+
args_ty_list=args_ty_list,
402+
)
403+
404+
if config.DEBUG_KERNEL_LAUNCHER:
405+
cgutils.printf(self._builder, "DPEX-DEBUG: Submit kernel.\n")
406+
407+
return self.submit_sycl_kernel(
408+
sycl_kernel_ref=kernel_ref,
409+
sycl_queue_ref=queue_ref,
410+
total_kernel_args=num_flattened_kernel_args,
411+
arg_list=args_list,
412+
arg_ty_list=args_ty_list,
413+
global_range=global_range_extents,
414+
local_range=local_range_extents,
415+
wait_before_return=False,
416+
)
417+
364418
def submit_sycl_kernel(
365419
self,
366420
sycl_kernel_ref,
@@ -373,7 +427,7 @@ def submit_sycl_kernel(
373427
wait_before_return=True,
374428
) -> llvmir.PointerType(llvmir.IntType(8)):
375429
"""
376-
Submits the kernel to the specified queue, waits.
430+
Submits the kernel to the specified queue, waits by default.
377431
"""
378432
eref = None
379433
gr = self._create_sycl_range(global_range)
@@ -411,19 +465,34 @@ def submit_sycl_kernel(
411465
else:
412466
return eref
413467

468+
def get_num_flattened_kernel_args(
469+
self,
470+
kernel_argtys: tuple[types.Type, ...],
471+
):
472+
num_flattened_kernel_args = 0
473+
for arg_type in kernel_argtys:
474+
if isinstance(arg_type, DpnpNdArray):
475+
datamodel = self.context.data_model_manager.lookup(arg_type)
476+
num_flattened_kernel_args += datamodel.flattened_field_count
477+
elif arg_type in [types.complex64, types.complex128]:
478+
num_flattened_kernel_args += 2
479+
else:
480+
num_flattened_kernel_args += 1
481+
482+
return num_flattened_kernel_args
483+
414484
def populate_kernel_args_and_args_ty_arrays(
415485
self,
416486
kernel_argtys,
417487
callargs_ptrs,
418488
args_list,
419489
args_ty_list,
420-
datamodel_mgr,
421490
):
422491
kernel_arg_num = 0
423492
for arg_num, argtype in enumerate(kernel_argtys):
424493
llvm_val = callargs_ptrs[arg_num]
425494
if isinstance(argtype, DpnpNdArray):
426-
datamodel = datamodel_mgr.lookup(argtype)
495+
datamodel = self.context.data_model_manager.lookup(argtype)
427496
self.build_array_arg(
428497
array_val=llvm_val,
429498
array_data_model=datamodel,

0 commit comments

Comments
 (0)