diff --git a/numba_dpex/experimental/decorators.py b/numba_dpex/experimental/decorators.py index a7529cea24..89ff2d7ff4 100644 --- a/numba_dpex/experimental/decorators.py +++ b/numba_dpex/experimental/decorators.py @@ -99,7 +99,7 @@ def _kernel_dispatcher(pyfunc): "Argument passed to the kernel decorator is neither a " "function object, nor a signature. If you are trying to " "specialize the kernel that takes a single argument, specify " - "the return type as void explicitly." + "the return type as None explicitly." ) return _kernel_dispatcher(func) @@ -132,13 +132,28 @@ def device_func(func_or_sig=None, **options): ) options["_compilation_mode"] = CompilationMode.DEVICE_FUNC + func, sigs = _parse_func_or_sig(func_or_sig) + for sig in sigs: + if isinstance(sig, str): + raise NotImplementedError( + "Specifying signatures as string is not yet supported" + ) + def _kernel_dispatcher(pyfunc): - return dispatcher( + disp: SPIRVKernelDispatcher = dispatcher( pyfunc=pyfunc, targetoptions=options, ) - if func_or_sig is None: + if len(sigs) > 0: + with typeinfer.register_dispatcher(disp): + for sig in sigs: + disp.compile(sig) + disp.disable_compile() + + return disp + + if func is None: return _kernel_dispatcher return _kernel_dispatcher(func_or_sig) diff --git a/numba_dpex/tests/kernel_tests/test_func_specialization.py b/numba_dpex/tests/kernel_tests/test_func_specialization.py index e44af1b7c3..223c9c0521 100644 --- a/numba_dpex/tests/kernel_tests/test_func_specialization.py +++ b/numba_dpex/tests/kernel_tests/test_func_specialization.py @@ -4,105 +4,88 @@ import dpnp import numpy as np -import pytest +from numba import int32, int64 -import numba_dpex as dpex -from numba_dpex import float32, int32 +import numba_dpex.experimental as dpex -single_signature = dpex.func(int32(int32)) -list_signature = dpex.func([int32(int32), float32(float32)]) +i32_signature = dpex.device_func(int32(int32)) +i32i64_signature = dpex.device_func([int32(int32), int64(int64)]) # Array size -N = 
10 +N = 1024 def increment(a): - return a + dpnp.float32(1) + return a + 1 -def test_basic(): - """Basic test with device func""" +fi32 = i32_signature(increment) +fi32i64 = i32i64_signature(increment) - f = dpex.func(increment) - def kernel_function(a, b): - """Kernel function that applies f() in parallel""" - i = dpex.get_global_id(0) - b[i] = f(a[i]) +@dpex.kernel +def kernel_function(item, a, b): + """Kernel function that calls fi32()""" + i = item.get_id(0) + b[i] = fi32(a[i]) - k = dpex.kernel(kernel_function) - a = dpnp.ones(N) - b = dpnp.ones(N) +@dpex.kernel +def kernel_function2(item, a, b): + """Kernel function that calls fi32i64()""" + i = item.get_id(0) + b[i] = fi32i64(a[i]) - dpex.call_kernel(k, dpex.Range(N), a, b) - assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1) - - -def test_single_signature(): - """Basic test with single signature""" - - fi32 = single_signature(increment) - - def kernel_function(a, b): - """Kernel function that applies fi32() in parallel""" - i = dpex.get_global_id(0) - b[i] = fi32(a[i]) - - k = dpex.kernel(kernel_function) - - # Test with int32, should work - a = dpnp.ones(N, dtype=dpnp.int32) - b = dpnp.ones(N, dtype=dpnp.int32) - - dpex.call_kernel(k, dpex.Range(N), a, b) - - assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1) - - # Test with int64, should fail - a = dpnp.ones(N, dtype=dpnp.int64) - b = dpnp.ones(N, dtype=dpnp.int64) - - with pytest.raises(Exception) as e: - dpex.call_kernel(k, dpex.Range(N), a, b) - - assert " >>> (int64)" in e.value.args[0] - - -def test_list_signature(): - """Basic test with list signature""" - - fi32f32 = list_signature(increment) - - def kernel_function(a, b): - """Kernel function that applies fi32f32() in parallel""" - i = dpex.get_global_id(0) - b[i] = fi32f32(a[i]) - - k = dpex.kernel(kernel_function) - - # Test with int32, should work +def test_calling_specialized_device_func(): + """Tests if a specialized device_func gets called as expected from kernel""" a = 
dpnp.ones(N, dtype=dpnp.int32) - b = dpnp.ones(N, dtype=dpnp.int32) + b = dpnp.zeros(N, dtype=dpnp.int32) - dpex.call_kernel(k, dpex.Range(N), a, b) + dpex.call_kernel(kernel_function, dpex.Range(N), a, b) assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1) - # Test with float32, should work - a = dpnp.ones(N, dtype=dpnp.float32) - b = dpnp.ones(N, dtype=dpnp.float32) - dpex.call_kernel(k, dpex.Range(N), a, b) +def test_calling_specialized_device_func_wrong_signature(): + """Tests that calling specialized signature with wrong signature does not + trigger recompilation. - assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1) + Tests kernel_function with float32. Numba will downcast float32 to int32 + and call the specialized function. The implicit casting is a problem, but + for the purpose of this test case, all we care about is to check if the + specialized function was called and we did not recompile the device_func. + Refer: https://github.com/numba/numba/issues/9506 + """ # Test with int64, should fail - a = dpnp.ones(N, dtype=dpnp.int64) - b = dpnp.ones(N, dtype=dpnp.int64) - - with pytest.raises(Exception) as e: - dpex.call_kernel(k, dpex.Range(N), a, b) - - assert " >>> (int64)" in e.value.args[0] + a = dpnp.full(N, 1.5, dtype=dpnp.float32) + b = dpnp.zeros(N, dtype=dpnp.float32) + + dpex.call_kernel(kernel_function, dpex.Range(N), a, b) + + # Since Numba is calling the i32 specialization of increment, the values in + # `a` are first downcast to int32, *i.e.*, 1.5 to 1, and then + # incremented. Thus, the output is 2 instead of 2.5. + # The implicit downcasting is a dangerous thing for Numba to do, but we use + # it to our advantage to test that recompilation did not happen for a + # specialized device function. 
+ assert np.all(dpnp.asnumpy(b) == 2) + assert not np.all(dpnp.asnumpy(b) == 2.5) + + +def test_multi_specialized_device_func(): + """Tests if a device_func with multiple specializations can be called + in a kernel. + """ + # Test with int32 and int64; both should work + ai32 = dpnp.ones(N, dtype=dpnp.int32) + bi32 = dpnp.ones(N, dtype=dpnp.int32) + ai64 = dpnp.ones(N, dtype=dpnp.int64) + bi64 = dpnp.ones(N, dtype=dpnp.int64) + + dpex.call_kernel(kernel_function2, dpex.Range(N), ai32, bi32) + dpex.call_kernel(kernel_function2, dpex.Range(N), ai64, bi64) + + assert np.array_equal(dpnp.asnumpy(bi32), dpnp.asnumpy(ai32) + 1) + assert np.array_equal(dpnp.asnumpy(bi64), dpnp.asnumpy(ai64) + 1)