Skip to content

Commit d1700ff

Browse files
author
Diptorup Deb
authored
Merge pull request #1230 from IntelPython/experimental/enable_overloads
Enables adding overload to DpexExpKernelTarget and fully inline them into the final module.
2 parents a62ff1a + 0af4414 commit d1700ff

11 files changed

+410
-24
lines changed

numba_dpex/core/codegen.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,25 @@ class SPIRVCodeLibrary(CPUCodeLibrary):
2929
def _optimize_functions(self, ll_module):
3030
pass
3131

32+
@property
33+
def inline_threshold(self):
34+
"""The inlining threshold value to be used to optimize the final library"""
35+
if hasattr(self, "_inline_threshold"):
36+
return self._inline_threshold
37+
else:
38+
return 0
39+
40+
@inline_threshold.setter
41+
def inline_threshold(self, value: int):
42+
"""Returns the current inlining threshold level for the library."""
43+
if value < 0 or value > 3:
44+
logging.warning(
45+
"Unsupported inline threshold. Set a value between 0 and 3"
46+
)
47+
self._inline_threshold = 0
48+
else:
49+
self._inline_threshold = value
50+
3251
def _optimize_final_module(self):
3352
# Run some lightweight optimization to simplify the module.
3453
pmb = ll.PassManagerBuilder()
@@ -43,12 +62,38 @@ def _optimize_final_module(self):
4362
)
4463

4564
pmb.disable_unit_at_a_time = False
65+
4666
if config.INLINE_THRESHOLD is not None:
47-
logging.warning(
48-
"Setting INLINE_THRESHOLD leads to very aggressive "
49-
+ "optimizations that may produce incorrect binary."
50-
)
51-
pmb.inlining_threshold = config.INLINE_THRESHOLD
67+
# Check if a decorator-level inline threshold was set and use that
68+
# instead of the global configuration.
69+
if (
70+
hasattr(self, "_inline_threshold")
71+
and self._inline_threshold > 0
72+
and self._inline_threshold <= 3
73+
):
74+
logging.warning(
75+
"Setting INLINE_THRESHOLD leads to very aggressive "
76+
+ "optimizations that may produce incorrect binary."
77+
)
78+
pmb.inlining_threshold = self._inline_threshold
79+
elif not hasattr(self, "_inline_threshold"):
80+
logging.warning(
81+
"Setting INLINE_THRESHOLD leads to very aggressive "
82+
+ "optimizations that may produce incorrect binary."
83+
)
84+
pmb.inlining_threshold = config.INLINE_THRESHOLD
85+
else:
86+
if (
87+
hasattr(self, "_inline_threshold")
88+
and self._inline_threshold > 0
89+
and self._inline_threshold <= 3
90+
):
91+
logging.warning(
92+
"Setting INLINE_THRESHOLD leads to very aggressive "
93+
+ "optimizations that may produce incorrect binary."
94+
)
95+
pmb.inlining_threshold = self._inline_threshold
96+
5297
pmb.disable_unroll_loops = True
5398
pmb.loop_vectorize = False
5499
pmb.slp_vectorize = False

numba_dpex/core/descriptor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
from numba.core.cpu import CPUTargetOptions
99
from numba.core.descriptors import TargetDescriptor
1010

11+
from numba_dpex import config
12+
1113
from .targets.dpjit_target import DPEX_TARGET_NAME, DpexTargetContext
1214
from .targets.kernel_target import (
1315
DPEX_KERNEL_TARGET_NAME,
16+
CompilationMode,
1417
DpexKernelTargetContext,
1518
DpexKernelTypingContext,
1619
)
@@ -40,13 +43,24 @@ class DpexTargetOptions(CPUTargetOptions):
4043
release_gil = _option_mapping("release_gil")
4144
no_compile = _option_mapping("no_compile")
4245
use_mlir = _option_mapping("use_mlir")
46+
inline_threshold = _option_mapping("inline_threshold")
47+
_compilation_mode = _option_mapping("_compilation_mode")
4348

4449
def finalize(self, flags, options):
4550
super().finalize(flags, options)
4651
_inherit_if_not_set(flags, options, "experimental", False)
4752
_inherit_if_not_set(flags, options, "release_gil", False)
4853
_inherit_if_not_set(flags, options, "no_compile", True)
4954
_inherit_if_not_set(flags, options, "use_mlir", False)
55+
if config.INLINE_THRESHOLD is not None:
56+
_inherit_if_not_set(
57+
flags, options, "inline_threshold", config.INLINE_THRESHOLD
58+
)
59+
else:
60+
_inherit_if_not_set(flags, options, "inline_threshold", 0)
61+
_inherit_if_not_set(
62+
flags, options, "_compilation_mode", CompilationMode.KERNEL
63+
)
5064

5165

5266
class DpexKernelTarget(TargetDescriptor):

numba_dpex/core/targets/kernel_target.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55

6+
from enum import IntEnum
67
from functools import cached_property
78

89
import dpnp
@@ -30,6 +31,28 @@
3031
LLVM_SPIRV_ARGS = 112
3132

3233

34+
class CompilationMode(IntEnum):
35+
"""Flags used to determine how a function should be compiled by the
36+
numba_dpex.experimental.dispatcher.KernelDispatcher. Note the functionality
37+
will be merged into numba_dpex.core.kernel_interface.dispatcher in the
38+
future.
39+
40+
KERNEL : Indicates that the function will be compiled into an
41+
LLVM function that has ``spir_kernel`` calling
42+
convention and is compiled down to SPIR-V.
43+
Additionally, the function cannot return any value and
44+
input arguments to the function have to adhere to
45+
"compute follows data" to ensure execution queue
46+
inference.
47+
DEVICE_FUNCTION: Indicates that the function will be compiled into an
48+
LLVM function that has ``spir_func`` calling convention
49+
and will be compiled only into LLVM bitcode.
50+
"""
51+
52+
KERNEL = 1
53+
DEVICE_FUNC = 2
54+
55+
3356
class DpexKernelTypingContext(typing.BaseContext):
3457
"""Custom typing context to support kernel compilation.
3558

numba_dpex/core/types/dpctl_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def __init__(self, sycl_queue):
2828
self._unique_id = hash(sycl_queue)
2929
except Exception:
3030
self._unique_id = self.rand_digit_str(16)
31-
super(DpctlSyclQueue, self).__init__(name="DpctlSyclQueue")
31+
super(DpctlSyclQueue, self).__init__(
32+
name=f"DpctlSyclQueue on {self._device}"
33+
)
3234

3335
def rand_digit_str(self, n):
3436
return "".join(

numba_dpex/experimental/decorators.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
ready to move to numba_dpex.core.
77
"""
88
import inspect
9+
from warnings import warn
910

1011
from numba.core import sigutils
1112
from numba.core.target_extension import (
@@ -14,6 +15,8 @@
1415
target_registry,
1516
)
1617

18+
from numba_dpex.core.targets.kernel_target import CompilationMode
19+
1720
from .target import DPEX_KERNEL_EXP_TARGET_NAME
1821

1922

@@ -30,6 +33,14 @@ def kernel(func_or_sig=None, **options):
3033
"""
3134

3235
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
36+
if "_compilation_mode" in options:
37+
user_compilation_mode = options["_compilation_mode"]
38+
warn(
39+
"_compilation_mode is an internal flag that should not be set "
40+
"in the decorator. The decorator defined option "
41+
f"{user_compilation_mode} is going to be ignored."
42+
)
43+
options["_compilation_mode"] = CompilationMode.KERNEL
3344

3445
# FIXME: The options need to be evaluated and checked here like it is
3546
# done in numba.core.decorators.jit
@@ -80,4 +91,44 @@ def _specialized_kernel_dispatcher(pyfunc):
8091
return _kernel_dispatcher(func)
8192

8293

83-
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = kernel
94+
def device_func(func_or_sig=None, **options):
95+
"""Generates a function with a device-only calling convention, e.g.,
96+
spir_func for SPIR-V based devices.
97+
98+
The decorator is used to compile overloads in the DpexKernelTarget and
99+
users should use the decorator to define functions that are only callable
100+
from inside another device_func or a kernel.
101+
102+
A device_func is not compiled down to device binary IR and instead left as
103+
LLVM IR. It is done so that the function can be inlined fully into the
104+
kernel module from where it is used at the LLVM level, leading to more
105+
optimization opportunities.
106+
107+
Returns:
108+
KernelDispatcher: A KernelDispatcher instance with the
109+
_compilation_mode option set to DEVICE_FUNC.
110+
"""
111+
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
112+
113+
if "_compilation_mode" in options:
114+
user_compilation_mode = options["_compilation_mode"]
115+
warn(
116+
"_compilation_mode is an internal flag that should not be set "
117+
"in the decorator. The decorator defined option "
118+
f"{user_compilation_mode} is going to be ignored."
119+
)
120+
options["_compilation_mode"] = CompilationMode.DEVICE_FUNC
121+
122+
def _kernel_dispatcher(pyfunc):
123+
return dispatcher(
124+
pyfunc=pyfunc,
125+
targetoptions=options,
126+
)
127+
128+
if func_or_sig is None:
129+
return _kernel_dispatcher
130+
131+
return _kernel_dispatcher(func_or_sig)
132+
133+
134+
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = device_func

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,21 @@
1111

1212
import numba.core.event as ev
1313
from numba.core import errors, sigutils, types
14-
from numba.core.compiler import CompileResult
14+
from numba.core.compiler import CompileResult, Flags
1515
from numba.core.compiler_lock import global_compiler_lock
1616
from numba.core.dispatcher import Dispatcher, _FunctionCompiler
1717
from numba.core.target_extension import dispatcher_registry, target_registry
18+
from numba.core.types import void
1819
from numba.core.typing.typeof import Purpose, typeof
1920

2021
from numba_dpex import config, spirv_generator
2122
from numba_dpex.core.exceptions import (
2223
ExecutionQueueInferenceError,
24+
KernelHasReturnValueError,
2325
UnsupportedKernelArgumentError,
2426
)
2527
from numba_dpex.core.pipelines import kernel_compiler
28+
from numba_dpex.core.targets.kernel_target import CompilationMode
2629
from numba_dpex.core.types import DpnpNdArray
2730

2831
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target
@@ -81,10 +84,19 @@ def _compile_to_spirv(
8184
kernel_fn = kernel_targetctx.prepare_spir_kernel(
8285
kernel_func, kernel_fndesc.argtypes
8386
)
84-
85-
# makes sure that the spir_func is completely inlined into the
86-
# spir_kernel wrapper
87-
kernel_library.optimize_final_module()
87+
# Get the compiler flags that were passed through the target descriptor
88+
flags = Flags()
89+
self.targetdescr.options.parse_as_flags(flags, self.targetoptions)
90+
91+
# If the inline_threshold option was set then set the property in the
92+
# kernel_library to force inlining ``overload`` calls into a kernel.
93+
inline_threshold = flags.inline_threshold # pylint: disable=E1101
94+
kernel_library.inline_threshold = inline_threshold
95+
96+
# Call finalize on the LLVM module. Finalization will result in
97+
# all linking libraries getting linked together and final optimization
98+
# including inlining of functions if an inlining level is specified.
99+
kernel_library.finalize()
88100
# Compiled the LLVM IR to SPIR-V
89101
kernel_spirv_module = spirv_generator.llvm_to_spirv(
90102
kernel_targetctx,
@@ -144,9 +156,15 @@ def _compile_cached(
144156
try:
145157
cres: CompileResult = self._compile_core(args, return_type)
146158

147-
kernel_device_ir_module = self._compile_to_spirv(
148-
cres.library, cres.fndesc, cres.target_context
149-
)
159+
if (
160+
self.targetoptions["_compilation_mode"]
161+
== CompilationMode.KERNEL
162+
):
163+
kernel_device_ir_module: _KernelModule = self._compile_to_spirv(
164+
cres.library, cres.fndesc, cres.target_context
165+
)
166+
else:
167+
kernel_device_ir_module = None
150168

151169
kcres_attrs = []
152170

@@ -185,9 +203,6 @@ class KernelDispatcher(Dispatcher):
185203
an executable binary, the dispatcher compiles it to SPIR-V and then caches
186204
that SPIR-V bitcode.
187205
188-
FIXME: Fix issues identified by pylint with this class.
189-
https://github.com/IntelPython/numba-dpex/issues/1196
190-
191206
"""
192207

193208
targetdescr = dpex_exp_kernel_target
@@ -282,12 +297,28 @@ def cb_llvm(dur):
282297
with self._compiling_counter:
283298
args, return_type = sigutils.normalize_signature(sig)
284299

285-
try:
286-
self._compiler.check_queue_equivalence_of_args(
287-
self._kernel_name, args
288-
)
289-
except ExecutionQueueInferenceError as eqie:
290-
raise eqie
300+
if (
301+
self.targetoptions["_compilation_mode"]
302+
== CompilationMode.KERNEL
303+
):
304+
# Compute follows data based queue equivalence is only
305+
# evaluated for kernel functions whose arguments are
306+
# supposed to be arrays. For device_func decorated
307+
# functions, the arguments can be scalar and we skip queue
308+
# equivalence check.
309+
try:
310+
self._compiler.check_queue_equivalence_of_args(
311+
self._kernel_name, args
312+
)
313+
except ExecutionQueueInferenceError as eqie:
314+
raise eqie
315+
316+
# A function being compiled in the KERNEL compilation mode
317+
# cannot have a non-void return value
318+
if return_type and return_type != void:
319+
raise KernelHasReturnValueError(
320+
kernel_name=None, return_type=return_type, sig=sig
321+
)
291322

292323
# Don't recompile if signature already exists
293324
existing = self.overloads.get(tuple(args))
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0

0 commit comments

Comments
 (0)