Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions numba_dpex/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _kernel_dispatcher(pyfunc):
"Argument passed to the kernel decorator is neither a "
"function object, nor a signature. If you are trying to "
"specialize the kernel that takes a single argument, specify "
"the return type as void explicitly."
"the return type as None explicitly."
)
return _kernel_dispatcher(func)

Expand Down Expand Up @@ -132,13 +132,28 @@ def device_func(func_or_sig=None, **options):
)
options["_compilation_mode"] = CompilationMode.DEVICE_FUNC

func, sigs = _parse_func_or_sig(func_or_sig)
for sig in sigs:
if isinstance(sig, str):
raise NotImplementedError(
"Specifying signatures as string is not yet supported"
)

def _kernel_dispatcher(pyfunc):
return dispatcher(
disp: SPIRVKernelDispatcher = dispatcher(
pyfunc=pyfunc,
targetoptions=options,
)

if func_or_sig is None:
if len(sigs) > 0:
with typeinfer.register_dispatcher(disp):
for sig in sigs:
disp.compile(sig)
disp.disable_compile()

return disp

if func is None:
return _kernel_dispatcher

return _kernel_dispatcher(func_or_sig)
Expand Down
139 changes: 61 additions & 78 deletions numba_dpex/tests/kernel_tests/test_func_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,105 +4,88 @@

import dpnp
import numpy as np
import pytest
from numba import int32, int64

import numba_dpex as dpex
from numba_dpex import float32, int32
import numba_dpex.experimental as dpex

single_signature = dpex.func(int32(int32))
list_signature = dpex.func([int32(int32), float32(float32)])
i32_signature = dpex.device_func(int32(int32))
i32i64_signature = dpex.device_func([int32(int32), int64(int64)])

# Array size
N = 10
N = 1024


def increment(a):
return a + dpnp.float32(1)
return a + 1


def test_basic():
"""Basic test with device func"""
fi32 = i32_signature(increment)
fi32i64 = i32i64_signature(increment)

f = dpex.func(increment)

def kernel_function(a, b):
"""Kernel function that applies f() in parallel"""
i = dpex.get_global_id(0)
b[i] = f(a[i])
@dpex.kernel
def kernel_function(item, a, b):
"""Kernel function that calls fi32()"""
i = item.get_id(0)
b[i] = fi32(a[i])

k = dpex.kernel(kernel_function)

a = dpnp.ones(N)
b = dpnp.ones(N)
@dpex.kernel
def kernel_function2(item, a, b):
"""Kernel function that calls fi32i64()"""
i = item.get_id(0)
b[i] = fi32i64(a[i])

dpex.call_kernel(k, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)


def test_single_signature():
"""Basic test with single signature"""

fi32 = single_signature(increment)

def kernel_function(a, b):
"""Kernel function that applies fi32() in parallel"""
i = dpex.get_global_id(0)
b[i] = fi32(a[i])

k = dpex.kernel(kernel_function)

# Test with int32, should work
a = dpnp.ones(N, dtype=dpnp.int32)
b = dpnp.ones(N, dtype=dpnp.int32)

dpex.call_kernel(k, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)

# Test with int64, should fail
a = dpnp.ones(N, dtype=dpnp.int64)
b = dpnp.ones(N, dtype=dpnp.int64)

with pytest.raises(Exception) as e:
dpex.call_kernel(k, dpex.Range(N), a, b)

assert " >>> <unknown function>(int64)" in e.value.args[0]


def test_list_signature():
"""Basic test with list signature"""

fi32f32 = list_signature(increment)

def kernel_function(a, b):
"""Kernel function that applies fi32f32() in parallel"""
i = dpex.get_global_id(0)
b[i] = fi32f32(a[i])

k = dpex.kernel(kernel_function)

# Test with int32, should work
def test_calling_specialized_device_func():
"""Tests if a specialized device_func gets called as expected from kernel"""
a = dpnp.ones(N, dtype=dpnp.int32)
b = dpnp.ones(N, dtype=dpnp.int32)
b = dpnp.zeros(N, dtype=dpnp.int32)

dpex.call_kernel(k, dpex.Range(N), a, b)
dpex.call_kernel(kernel_function, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)

# Test with float32, should work
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.ones(N, dtype=dpnp.float32)

dpex.call_kernel(k, dpex.Range(N), a, b)
def test_calling_specialized_device_func_wrong_signature():
"""Tests that calling specialized signature with wrong signature does not
trigger recompilation.

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)
Tests kernel_function with float32. Numba will downcast float32 to int32
and call the specialized function. The implicit casting is a problem, but
for the purpose of this test case, all we care about is checking that the
specialized function was called and that we did not recompile the
device_func. Refer to: https://github.com/numba/numba/issues/9506

"""
# Test with int64, should fail
a = dpnp.ones(N, dtype=dpnp.int64)
b = dpnp.ones(N, dtype=dpnp.int64)

with pytest.raises(Exception) as e:
dpex.call_kernel(k, dpex.Range(N), a, b)

assert " >>> <unknown function>(int64)" in e.value.args[0]
a = dpnp.full(N, 1.5, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.float32)

dpex.call_kernel(kernel_function, dpex.Range(N), a, b)

# Since Numba is calling the i32 specialization of increment, the values in
# `a` are first down-converted to int32, *i.e.*, 1.5 to 1, and then
# incremented. Thus, the output is 2 instead of 2.5.
# The implicit downcasting is a dangerous thing for Numba to do, but we use
# it to our advantage to test that recompilation did not happen for a
# specialized device function.
assert np.all(dpnp.asnumpy(b) == 2)
assert not np.all(dpnp.asnumpy(b) == 2.5)


def test_multi_specialized_device_func():
"""Tests if a device_func with multiple specialization can be called
in a kernel
"""
# Test with int32 and int64; both specializations should work
ai32 = dpnp.ones(N, dtype=dpnp.int32)
bi32 = dpnp.ones(N, dtype=dpnp.int32)
ai64 = dpnp.ones(N, dtype=dpnp.int64)
bi64 = dpnp.ones(N, dtype=dpnp.int64)

dpex.call_kernel(kernel_function2, dpex.Range(N), ai32, bi32)
dpex.call_kernel(kernel_function2, dpex.Range(N), ai64, bi64)

assert np.array_equal(dpnp.asnumpy(bi32), dpnp.asnumpy(ai32) + 1)
assert np.array_equal(dpnp.asnumpy(bi64), dpnp.asnumpy(ai64) + 1)