Skip to content

Commit 0af4414

Browse files
author
Diptorup Deb
committed
Unit test to check if inline_threshold works
1 parent 38edc59 commit 0af4414

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpctl
6+
from numba.core import types
7+
8+
import numba_dpex as dpex
9+
from numba_dpex import DpctlSyclQueue, DpnpNdArray
10+
from numba_dpex import experimental as dpex_exp
11+
from numba_dpex import int64
12+
13+
14+
def kernel_func(a, b, c):
15+
i = dpex.get_global_id(0)
16+
c[i] = a[i] + b[i]
17+
18+
19+
def test_codegen_with_max_inline_threshold():
20+
"""Tests if the inline_threshold option leads to a fully inlined kernel
21+
function generation.
22+
23+
By default, numba_dpex compiles a function passed to the `kernel` decorator
24+
into a `spir_func` LLVM function. Then before lowering to device IR, the
25+
DpexTargetContext creates a "wrapper" function that has the "spir_kernel"
26+
calling convention. It is done so that we can use the same target context
27+
and pipeline to compile both host callable "kernels" and device-only
28+
"device_func" functions.
29+
30+
Unless the inline_threshold is set to 3, the `spir_func` function is not
31+
inlined into the wrapper function. The test checks if the `spir_func`
32+
function is fully inlined into the wrapper. The test is rather rudimentary
33+
and only checks the count of function in the generated module.
34+
With inlining, the count should be one and without inlining it will be two.
35+
"""
36+
37+
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
38+
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
39+
kernel_sig = types.void(i64arr_ty, i64arr_ty, i64arr_ty)
40+
41+
disp = dpex_exp.kernel(inline_threshold=3)(kernel_func)
42+
disp.compile(kernel_sig)
43+
kcres = disp.overloads[kernel_sig.args]
44+
llvm_ir_mod = kcres.library._final_module
45+
46+
count_of_non_declaration_type_functions = 0
47+
48+
for f in llvm_ir_mod.functions:
49+
if not f.is_declaration:
50+
count_of_non_declaration_type_functions += 1
51+
52+
assert count_of_non_declaration_type_functions == 1
53+
54+
55+
def test_codegen_without_max_inline_threshold():
56+
"""See docstring of :func:`test_codegen_with_max_inline_threshold`."""
57+
58+
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
59+
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
60+
kernel_sig = types.void(i64arr_ty, i64arr_ty, i64arr_ty)
61+
62+
disp = dpex_exp.kernel(kernel_func)
63+
disp.compile(kernel_sig)
64+
kcres = disp.overloads[kernel_sig.args]
65+
llvm_ir_mod = kcres.library._final_module
66+
67+
count_of_non_declaration_type_functions = 0
68+
69+
for f in llvm_ir_mod.functions:
70+
if not f.is_declaration:
71+
count_of_non_declaration_type_functions += 1
72+
73+
assert count_of_non_declaration_type_functions == 2

0 commit comments

Comments
 (0)