Skip to content

Commit 83806ee

Browse files
author
Diptorup Deb
committed
Temporary driver for testing
1 parent b00e75d commit 83806ee

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

driver.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import dpnp
2+
from numba.core import types
3+
from numba.core.extending import intrinsic, overload
4+
5+
import numba_dpex as dpex
6+
import numba_dpex.experimental as dpex_exp
7+
from numba_dpex.experimental.target import DPEX_KERNEL_EXP_TARGET_NAME
8+
9+
10+
# pylint: disable=W0613
11+
def scalar_add(a, b):
12+
return a + b
13+
14+
15+
@overload(scalar_add, target=DPEX_KERNEL_EXP_TARGET_NAME)
16+
def _ol_scalar_add(a, b):
17+
def ol_scalar_add_impl(a, b):
18+
return a + b
19+
20+
return ol_scalar_add_impl
21+
22+
23+
@dpex_exp.kernel
24+
def test_overload_call(a, b, c):
25+
i = dpex.get_global_id(0)
26+
c[i] = scalar_add(a[i], b[i])
27+
28+
29+
a = dpnp.ones(10, dtype=dpnp.int64)
30+
b = dpnp.ones(10, dtype=dpnp.int64)
31+
c = dpnp.zeros(10, dtype=dpnp.int64)
32+
33+
dpex_exp.call_kernel(test_overload_call, dpex.Range(10), a, b, c)
34+
35+
print(c)

0 commit comments

Comments
 (0)