Skip to content

Commit fa3cd55

Browse files
Update dpnp.power using dpctl and OneMKL implementations (#1476)
* Reuse dpctl.tensor.pow for dpnp.power * Add pow call from OneMKL by pybind11 extension * Update all tests for dpnp.power * Update examples for dpnp.power * Update dpnp_power and use OneMKL only on Linux for it * Restore deleted funcs in test_arithmetic * Remove dpnp_init_val * Skip test_copy --------- Co-authored-by: Anton <[email protected]>
1 parent b694d87 commit fa3cd55

17 files changed

+388
-732
lines changed

dpnp/backend/extensions/vm/pow.hpp

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event pow_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
const char *in_b,
46+
char *out_y,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
type_utils::validate_type_for_device<T>(exec_q);
50+
51+
const T *a = reinterpret_cast<const T *>(in_a);
52+
const T *b = reinterpret_cast<const T *>(in_b);
53+
T *y = reinterpret_cast<T *>(out_y);
54+
55+
return mkl_vm::pow(exec_q,
56+
n, // number of elements to be calculated
57+
a, // pointer `a` containing 1st input vector of size n
58+
b, // pointer `b` containing 2nd input vector of size n
59+
y, // pointer `y` to the output vector of size n
60+
depends);
61+
}
62+
63+
template <typename fnT, typename T>
64+
struct PowContigFactory
65+
{
66+
fnT get()
67+
{
68+
if constexpr (std::is_same_v<
69+
typename types::PowOutputType<T>::value_type, void>)
70+
{
71+
return nullptr;
72+
}
73+
else {
74+
return pow_contig_impl<T>;
75+
}
76+
}
77+
};
78+
} // namespace vm
79+
} // namespace ext
80+
} // namespace backend
81+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

+25
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,31 @@ struct MulOutputType
203203
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
204204
};
205205

206+
/**
207+
* @brief A factory to define pairs of supported types for which
208+
* MKL VM library provides support in oneapi::mkl::vm::pow<T> function.
209+
*
210+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
211+
*/
212+
template <typename T>
213+
struct PowOutputType
214+
{
215+
using value_type = typename std::disjunction<
216+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
217+
std::complex<double>,
218+
T,
219+
std::complex<double>,
220+
std::complex<double>>,
221+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
222+
std::complex<float>,
223+
T,
224+
std::complex<float>,
225+
std::complex<float>>,
226+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
227+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
228+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
229+
};
230+
206231
/**
207232
* @brief A factory to define pairs of supported types for which
208233
* MKL VM library provides support in oneapi::mkl::vm::rint<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "floor.hpp"
4040
#include "ln.hpp"
4141
#include "mul.hpp"
42+
#include "pow.hpp"
4243
#include "round.hpp"
4344
#include "sin.hpp"
4445
#include "sqr.hpp"
@@ -61,6 +62,7 @@ static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
6162
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
6263
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
6364
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
65+
static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types];
6466
static unary_impl_fn_ptr_t round_dispatch_vector[dpctl_td_ns::num_types];
6567
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
6668
static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types];
@@ -303,6 +305,36 @@ PYBIND11_MODULE(_vm_impl, m)
303305
py::arg("dst"));
304306
}
305307

308+
// BinaryUfunc: ==== Pow(x1, x2) ====
309+
{
310+
vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
311+
vm_ext::PowContigFactory>(
312+
pow_dispatch_vector);
313+
314+
auto pow_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
315+
arrayT dst, const event_vecT &depends = {}) {
316+
return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
317+
pow_dispatch_vector);
318+
};
319+
m.def("_pow", pow_pyapi,
320+
"Call `pow` function from OneMKL VM library to performs element "
321+
"by element exponentiation of vector `src1` raised to the power "
322+
"of vector `src2` to resulting vector `dst`",
323+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
324+
py::arg("dst"), py::arg("depends") = py::list());
325+
326+
auto pow_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
327+
arrayT src2, arrayT dst) {
328+
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
329+
pow_dispatch_vector);
330+
};
331+
m.def("_mkl_pow_to_call", pow_need_to_call_pyapi,
332+
"Check input arguments to answer if `pow` function from "
333+
"OneMKL VM library can be used",
334+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
335+
py::arg("dst"));
336+
}
337+
306338
// UnaryUfunc: ==== Round(x) ====
307339
{
308340
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,6 @@ enum class DPNPFuncName : size_t
290290
parameters */
291291
DPNP_FN_PLACE, /**< Used in numpy.place() impl */
292292
DPNP_FN_POWER, /**< Used in numpy.power() impl */
293-
DPNP_FN_POWER_EXT, /**< Used in numpy.power() impl, requires extra
294-
parameters */
295293
DPNP_FN_PROD, /**< Used in numpy.prod() impl */
296294
DPNP_FN_PROD_EXT, /**< Used in numpy.prod() impl, requires extra parameters
297295
*/

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

-7
Original file line numberDiff line numberDiff line change
@@ -1565,13 +1565,6 @@ static void func_map_elemwise_2arg_3type_core(func_map_t &fmap)
15651565
func_type_map_t::find_type<FT1>,
15661566
func_type_map_t::find_type<FTs>>}),
15671567
...);
1568-
((fmap[DPNPFuncName::DPNP_FN_POWER_EXT][FT1][FTs] =
1569-
{populate_func_types<FT1, FTs>(),
1570-
(void *)dpnp_power_c_ext<
1571-
func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
1572-
func_type_map_t::find_type<FT1>,
1573-
func_type_map_t::find_type<FTs>>}),
1574-
...);
15751568
((fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][FT1][FTs] =
15761569
{populate_func_types<FT1, FTs>(),
15771570
(void *)dpnp_subtract_c_ext<

dpnp/dpnp_algo/dpnp_algo.pxd

-7
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
128128
DPNP_FN_HYPOT_EXT
129129
DPNP_FN_IDENTITY
130130
DPNP_FN_IDENTITY_EXT
131-
DPNP_FN_INITVAL
132-
DPNP_FN_INITVAL_EXT
133131
DPNP_FN_INV
134132
DPNP_FN_INV_EXT
135133
DPNP_FN_KRON
@@ -164,8 +162,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
164162
DPNP_FN_PARTITION
165163
DPNP_FN_PARTITION_EXT
166164
DPNP_FN_PLACE
167-
DPNP_FN_POWER
168-
DPNP_FN_POWER_EXT
169165
DPNP_FN_PROD
170166
DPNP_FN_PROD_EXT
171167
DPNP_FN_PTP
@@ -407,7 +403,6 @@ cpdef dpnp_descriptor dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_
407403
"""
408404
Array creation routines
409405
"""
410-
cpdef dpnp_descriptor dpnp_init_val(shape, dtype, value)
411406
cpdef dpnp_descriptor dpnp_copy(dpnp_descriptor x1)
412407

413408
"""
@@ -421,8 +416,6 @@ cpdef dpnp_descriptor dpnp_maximum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
421416
dpnp_descriptor out=*, object where=*)
422417
cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
423418
dpnp_descriptor out=*, object where=*)
424-
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
425-
dpnp_descriptor out=*, object where=*)
426419

427420
"""
428421
Array manipulation routines

dpnp/dpnp_algo/dpnp_algo.pyx

-37
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ import numpy
5656
__all__ = [
5757
"dpnp_astype",
5858
"dpnp_flatten",
59-
"dpnp_init_val",
6059
"dpnp_queue_initialize",
6160
]
6261

@@ -85,9 +84,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_flatten_t)(c_dpctl.DPCTLSyclQueueR
8584
const shape_elem_type * , const shape_elem_type * ,
8685
const long * ,
8786
const c_dpctl.DPCTLEventVectorRef)
88-
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_initval_t)(c_dpctl.DPCTLSyclQueueRef,
89-
void *, void * , size_t,
90-
const c_dpctl.DPCTLEventVectorRef)
9187

9288

9389
cpdef utils.dpnp_descriptor dpnp_astype(utils.dpnp_descriptor x1, dtype):
@@ -168,39 +164,6 @@ cpdef utils.dpnp_descriptor dpnp_flatten(utils.dpnp_descriptor x1):
168164
return result
169165

170166

171-
cpdef utils.dpnp_descriptor dpnp_init_val(shape, dtype, value):
172-
"""
173-
same as dpnp_full(). TODO remove code duplication
174-
"""
175-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(dtype)
176-
177-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_INITVAL_EXT, param1_type, param1_type)
178-
179-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(shape, dtype, None)
180-
181-
result_obj = result.get_array()
182-
183-
# TODO: find better way to pass single value with type conversion
184-
cdef utils.dpnp_descriptor val_arr = utils_py.create_output_descriptor_py((1, ),
185-
dtype,
186-
None,
187-
device=result_obj.sycl_device,
188-
usm_type=result_obj.usm_type,
189-
sycl_queue=result_obj.sycl_queue)
190-
val_arr.get_pyobj()[0] = value
191-
192-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_obj.sycl_queue
193-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
194-
195-
cdef fptr_dpnp_initval_t func = <fptr_dpnp_initval_t > kernel_data.ptr
196-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, result.get_data(), val_arr.get_data(), result.size, NULL)
197-
198-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
199-
c_dpctl.DPCTLEvent_Delete(event_ref)
200-
201-
return result
202-
203-
204167
cpdef dpnp_queue_initialize():
205168
"""
206169
Initialize SYCL queue which will be used for any library operations.

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

-9
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ __all__ += [
5555
"dpnp_nancumsum",
5656
"dpnp_nanprod",
5757
"dpnp_nansum",
58-
"dpnp_power",
5958
"dpnp_prod",
6059
"dpnp_sum",
6160
"dpnp_trapz",
@@ -417,14 +416,6 @@ cpdef utils.dpnp_descriptor dpnp_nansum(utils.dpnp_descriptor x1):
417416
return dpnp_sum(result)
418417

419418

420-
cpdef utils.dpnp_descriptor dpnp_power(utils.dpnp_descriptor x1_obj,
421-
utils.dpnp_descriptor x2_obj,
422-
object dtype=None,
423-
utils.dpnp_descriptor out=None,
424-
object where=True):
425-
return call_fptr_2in_1out_strides(DPNP_FN_POWER_EXT, x1_obj, x2_obj, dtype, out, where, func_name="power")
426-
427-
428419
cpdef utils.dpnp_descriptor dpnp_prod(utils.dpnp_descriptor x1,
429420
object axis=None,
430421
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

+66
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
# *****************************************************************************
2828

2929

30+
from sys import platform
31+
3032
import dpctl.tensor._tensor_impl as ti
3133
from dpctl.tensor._elementwise_common import (
3234
BinaryElementwiseFunc,
@@ -68,6 +70,7 @@
6870
"dpnp_multiply",
6971
"dpnp_negative",
7072
"dpnp_not_equal",
73+
"dpnp_power",
7174
"dpnp_proj",
7275
"dpnp_remainder",
7376
"dpnp_right_shift",
@@ -1460,6 +1463,69 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
14601463
return dpnp_array._create_from_usm_ndarray(res_usm)
14611464

14621465

1466+
_power_docstring_ = """
1467+
power(x1, x2, out=None, order="K")
1468+
1469+
Calculates `x1_i` raised to `x2_i` for each element `x1_i` of the input array
1470+
`x1` with the respective element `x2_i` of the input array `x2`.
1471+
1472+
Args:
1473+
x1 (dpnp.ndarray):
1474+
First input array, expected to have numeric data type.
1475+
x2 (dpnp.ndarray):
1476+
Second input array, also expected to have numeric data type.
1477+
out ({None, dpnp.ndarray}, optional):
1478+
Output array to populate. Array must have the correct
1479+
shape and the expected data type.
1480+
order ("C","F","A","K", None, optional):
1481+
Memory layout of the newly created output array, if parameter `out` is `None`.
1482+
Default: "K".
1483+
Returns:
1484+
dpnp.ndarray:
1485+
An array containing the element-wise result of raising each element
1486+
to a specified power.
1487+
The data type of the returned array is determined by the Type Promotion Rules.
1488+
"""
1489+
1490+
1491+
def _call_pow(src1, src2, dst, sycl_queue, depends=None):
    """
    Callback registered in BinaryElementwiseFunc class of dpctl.tensor.

    Uses `vmi._pow` when the platform is not Windows and
    `vmi._mkl_pow_to_call` accepts the arguments; in every other case
    the call goes to `ti._pow`.
    """

    deps = [] if depends is None else depends

    # TODO: remove this check when OneMKL is fixed on Windows
    use_mkl = not platform.startswith("win") and vmi._mkl_pow_to_call(
        sycl_queue, src1, src2, dst
    )
    if not use_mkl:
        return ti._pow(src1, src2, dst, sycl_queue, deps)

    # call pybind11 extension for pow() function from OneMKL VM
    return vmi._pow(sycl_queue, src1, src2, dst, deps)


# Element-wise function object wiring the callback above to the `pow`
# result-type resolver and docstring.
pow_func = BinaryElementwiseFunc(
    "pow",
    ti._pow_result_type,
    _call_pow,
    _power_docstring_,
)
1509+
1510+
1511+
def dpnp_power(x1, x2, out=None, order="K"):
    """
    Compute the power function through `pow_func`.

    Invokes pow() function from pybind11 extension of OneMKL VM if possible.
    Otherwise fully relies on dpctl.tensor implementation for pow() function.

    The inputs (and `out`, when given) are converted to their usm_ndarray /
    scalar form before the call, and the result is wrapped back into
    a dpnp array.
    """

    # dpctl.tensor only works with usm_ndarray or scalar
    lhs = dpnp.get_usm_ndarray_or_scalar(x1)
    rhs = dpnp.get_usm_ndarray_or_scalar(x2)
    out_usm = dpnp.get_usm_ndarray(out) if out is not None else None

    return dpnp_array._create_from_usm_ndarray(
        pow_func(lhs, rhs, out=out_usm, order=order)
    )
1527+
1528+
14631529
_proj_docstring = """
14641530
proj(x, out=None, order="K")
14651531

0 commit comments

Comments
 (0)