Skip to content

Commit cf04418

Browse files
author
Diptorup Deb
committed
Add unit test for overload compilation.
1 parent 89f1a1c commit cf04418

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpnp
6+
from numba.core.extending import overload
7+
8+
import numba_dpex as dpex
9+
import numba_dpex.experimental as dpex_exp
10+
from numba_dpex.core.descriptor import dpex_kernel_target
11+
from numba_dpex.experimental.target import (
12+
DPEX_KERNEL_EXP_TARGET_NAME,
13+
dpex_exp_kernel_target,
14+
)
15+
16+
17+
def scalar_add(a, b):
18+
return a + b
19+
20+
21+
@overload(scalar_add, target=DPEX_KERNEL_EXP_TARGET_NAME)
22+
def _ol_scalar_add(a, b):
23+
def ol_scalar_add_impl(a, b):
24+
return a + b
25+
26+
return ol_scalar_add_impl
27+
28+
29+
@dpex_exp.kernel
30+
def kernel_calling_overload(a, b, c):
31+
i = dpex.get_global_id(0)
32+
c[i] = scalar_add(a[i], b[i])
33+
34+
35+
a = dpnp.ones(10, dtype=dpnp.int64)
36+
b = dpnp.ones(10, dtype=dpnp.int64)
37+
c = dpnp.zeros(10, dtype=dpnp.int64)
38+
39+
dpex_exp.call_kernel(kernel_calling_overload, dpex.Range(10), a, b, c)
40+
41+
42+
def test_end_to_end_overload_execution():
43+
"""Tests that an overload function can be called from an experimental.kernel
44+
decorated function and works end to end.
45+
"""
46+
for i in range(c.shape[0]):
47+
assert c[i] == scalar_add(a[i], b[i])
48+
49+
50+
def test_overload_registration():
51+
"""Tests that the overload _ol_scalar_add is registered only in the
52+
"dpex_kernel_exp" target and not in the "dpex_kernel" target.
53+
"""
54+
55+
def check_for_overload_registration(targetctx, key):
56+
found_key = False
57+
for fn_key in targetctx._defns.keys():
58+
if isinstance(fn_key, str) and fn_key.startswith(key):
59+
found_key = True
60+
break
61+
return found_key
62+
63+
assert check_for_overload_registration(
64+
dpex_exp_kernel_target.target_context, "_ol_scalar_add"
65+
)
66+
assert not check_for_overload_registration(
67+
dpex_kernel_target.target_context, "_ol_scalar_add"
68+
)

0 commit comments

Comments
 (0)