Skip to content

Commit ab97cf3

Browse files
authored
Merge pull request #1519 from IntelPython/use_dpctl_conj_in_dpnp
use_dpctl_conj_for_dpnp
2 parents 36f26aa + 653ce2e commit ab97cf3

15 files changed

+256
-57
lines changed

dpnp/backend/extensions/vm/conj.hpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
/**
 * @brief Computes the conjugate of each element of a contiguous vector by
 * forwarding the call to `mkl_vm::conj`.
 *
 * NOTE(review): `mkl_vm` is presumably an alias of `oneapi::mkl::vm` declared
 * in "common.hpp" -- confirm.
 *
 * @tparam T Element type both raw buffers are reinterpreted as.
 * @param exec_q Queue handed to the type validation helper and to
 *               `mkl_vm::conj`.
 * @param n Number of elements to be calculated.
 * @param in_a Type-erased pointer to the input vector; viewed as `const T *`.
 * @param out_y Type-erased pointer to the output vector; viewed as `T *`.
 * @param depends Events forwarded to `mkl_vm::conj` as its dependency list.
 * @return The `sycl::event` returned by `mkl_vm::conj`.
 */
template <typename T>
42+
sycl::event conj_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
// NOTE(review): presumably rejects T when the device behind exec_q cannot
// handle it (e.g. no double precision support) -- confirm in type_utils.
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
// The buffers arrive as char pointers so that every instantiation shares one
// signature; view them as T before making the typed VM call.
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::conj(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
/**
 * @brief Factory that selects the contiguous `conj` implementation for the
 * element type `T`.
 *
 * `get()` returns `nullptr` when `types::ConjOutputType<T>::value_type` is
 * `void`, and a pointer to `conj_contig_impl<T>` otherwise.
 *
 * @tparam fnT Function pointer type returned by `get()`.
 * @tparam T Element type looked up in `types::ConjOutputType`.
 */
template <typename fnT, typename T>
61+
struct ConjContigFactory
62+
{
63+
fnT get()
64+
{
65+
// `if constexpr` discards the untaken branch, so conj_contig_impl<T> is
// only instantiated for types that have a non-void output mapping.
if constexpr (std::is_same_v<
66+
typename types::ConjOutputType<T>::value_type, void>)
67+
{
68+
// NOTE(review): presumably a null entry tells the caller that no VM
// implementation exists for T -- confirm against the dispatch code.
return nullptr;
69+
}
70+
else {
71+
return conj_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,23 @@ struct CeilOutputType
8383
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
8484
};
8585

86+
/**
87+
* @brief A factory to define pairs of supported types for which
88+
* MKL VM library provides support in oneapi::mkl::vm::conj<T> function.
89+
*
* The listed pairs map std::complex<double> and std::complex<float> to
* themselves; `value_type` is taken from the `result_type` of the selected
* entry.
*
90+
* @tparam T Type of input vector `a` and of result vector `y`.
91+
*/
92+
template <typename T>
93+
struct ConjOutputType
94+
{
95+
// std::disjunction selects the first entry whose `value` is true, or the
// last entry when none is, so DefaultResultEntry<void> acts as the fallback.
// NOTE(review): presumably this makes value_type `void` for every T other
// than the two complex types -- confirm against dpctl's type dispatch helpers.
using value_type = typename std::disjunction<
96+
dpctl_td_ns::
97+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
98+
dpctl_td_ns::
99+
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
100+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
101+
};
102+
86103
/**
87104
* @brief A factory to define pairs of supported types for which
88105
* MKL VM library provides support in oneapi::mkl::vm::cos<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "add.hpp"
3434
#include "ceil.hpp"
3535
#include "common.hpp"
36+
#include "conj.hpp"
3637
#include "cos.hpp"
3738
#include "div.hpp"
3839
#include "floor.hpp"
@@ -56,6 +57,7 @@ static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types];
5657
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
5758
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
5859
static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
60+
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
5961
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
6062
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
6163
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
@@ -127,6 +129,34 @@ PYBIND11_MODULE(_vm_impl, m)
127129
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
128130
}
129131

132+
// UnaryUfunc: ==== Conj(x) ====
133+
{
134+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
135+
vm_ext::ConjContigFactory>(
136+
conj_dispatch_vector);
137+
138+
auto conj_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
139+
const event_vecT &depends = {}) {
140+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
141+
conj_dispatch_vector);
142+
};
143+
m.def("_conj", conj_pyapi,
144+
"Call `conj` function from OneMKL VM library to compute "
145+
"conjugate of vector elements",
146+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
147+
py::arg("depends") = py::list());
148+
149+
auto conj_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
150+
arrayT dst) {
151+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
152+
conj_dispatch_vector);
153+
};
154+
m.def("_mkl_conj_to_call", conj_need_to_call_pyapi,
155+
"Check input arguments to answer if `conj` function from "
156+
"OneMKL VM library can be used",
157+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
158+
}
159+
130160
// UnaryUfunc: ==== Cos(x) ====
131161
{
132162
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,13 @@ enum class DPNPFuncName : size_t
114114
*/
115115
DPNP_FN_CEIL, /**< Used in numpy.ceil() impl */
116116
DPNP_FN_CHOLESKY, /**< Used in numpy.linalg.cholesky() impl */
117-
DPNP_FN_CHOLESKY_EXT, /**< Used in numpy.linalg.cholesky() impl, requires
118-
extra parameters */
119-
DPNP_FN_CONJIGUATE, /**< Used in numpy.conjugate() impl */
120-
DPNP_FN_CONJIGUATE_EXT, /**< Used in numpy.conjugate() impl, requires extra
121-
parameters */
122-
DPNP_FN_CHOOSE, /**< Used in numpy.choose() impl */
123-
DPNP_FN_CHOOSE_EXT, /**< Used in numpy.choose() impl, requires extra
124-
parameters */
125-
DPNP_FN_COPY, /**< Used in numpy.copy() impl */
117+
DPNP_FN_CHOLESKY_EXT, /**< Used in numpy.linalg.cholesky() impl, requires
118+
extra parameters */
119+
DPNP_FN_CONJUGATE, /**< Used in numpy.conjugate() impl */
120+
DPNP_FN_CHOOSE, /**< Used in numpy.choose() impl */
121+
DPNP_FN_CHOOSE_EXT, /**< Used in numpy.choose() impl, requires extra
122+
parameters */
123+
DPNP_FN_COPY, /**< Used in numpy.copy() impl */
126124
DPNP_FN_COPY_EXT, /**< Used in numpy.copy() impl, requires extra parameters
127125
*/
128126
DPNP_FN_COPYSIGN, /**< Used in numpy.copysign() impl */

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,28 +1002,17 @@ constexpr auto dispatch_fmod_op(T elem1, T elem2)
10021002

10031003
static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
10041004
{
1005-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_INT][eft_INT] = {
1005+
fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_INT][eft_INT] = {
10061006
eft_INT, (void *)dpnp_copy_c_default<int32_t>};
1007-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_LNG][eft_LNG] = {
1007+
fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_LNG][eft_LNG] = {
10081008
eft_LNG, (void *)dpnp_copy_c_default<int64_t>};
1009-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_FLT][eft_FLT] = {
1009+
fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_FLT][eft_FLT] = {
10101010
eft_FLT, (void *)dpnp_copy_c_default<float>};
1011-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_DBL][eft_DBL] = {
1011+
fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_DBL][eft_DBL] = {
10121012
eft_DBL, (void *)dpnp_copy_c_default<double>};
1013-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_C128][eft_C128] = {
1013+
fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_C128][eft_C128] = {
10141014
eft_C128, (void *)dpnp_conjugate_c_default<std::complex<double>>};
10151015

1016-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_INT][eft_INT] = {
1017-
eft_INT, (void *)dpnp_copy_c_ext<int32_t>};
1018-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_LNG][eft_LNG] = {
1019-
eft_LNG, (void *)dpnp_copy_c_ext<int64_t>};
1020-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_FLT][eft_FLT] = {
1021-
eft_FLT, (void *)dpnp_copy_c_ext<float>};
1022-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_DBL][eft_DBL] = {
1023-
eft_DBL, (void *)dpnp_copy_c_ext<double>};
1024-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_C128][eft_C128] = {
1025-
eft_C128, (void *)dpnp_conjugate_c_ext<std::complex<double>>};
1026-
10271016
fmap[DPNPFuncName::DPNP_FN_COPY][eft_BLN][eft_BLN] = {
10281017
eft_BLN, (void *)dpnp_copy_c_default<bool>};
10291018
fmap[DPNPFuncName::DPNP_FN_COPY][eft_INT][eft_INT] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
6868
DPNP_FN_CHOLESKY_EXT
6969
DPNP_FN_CHOOSE
7070
DPNP_FN_CHOOSE_EXT
71-
DPNP_FN_CONJIGUATE
72-
DPNP_FN_CONJIGUATE_EXT
7371
DPNP_FN_COPY
7472
DPNP_FN_COPY_EXT
7573
DPNP_FN_COPYSIGN

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ __all__ += [
3939
"dpnp_absolute",
4040
"dpnp_arctan2",
4141
"dpnp_around",
42-
"dpnp_conjugate",
4342
"dpnp_copysign",
4443
"dpnp_cross",
4544
"dpnp_cumprod",
@@ -155,10 +154,6 @@ cpdef utils.dpnp_descriptor dpnp_around(utils.dpnp_descriptor x1, int decimals):
155154
return result
156155

157156

158-
cpdef utils.dpnp_descriptor dpnp_conjugate(utils.dpnp_descriptor x1):
159-
return call_fptr_1in_1out_strides(DPNP_FN_CONJIGUATE_EXT, x1)
160-
161-
162157
cpdef utils.dpnp_descriptor dpnp_copysign(utils.dpnp_descriptor x1_obj,
163158
utils.dpnp_descriptor x2_obj,
164159
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"dpnp_bitwise_or",
4848
"dpnp_bitwise_xor",
4949
"dpnp_ceil",
50+
"dpnp_conj",
5051
"dpnp_cos",
5152
"dpnp_divide",
5253
"dpnp_equal",
@@ -419,20 +420,41 @@ def dpnp_ceil(x, out=None, order="K"):
419420
"""
420421

421422

422-
def _call_cos(src, dst, sycl_queue, depends=None):
423+
_conj_docstring = """
424+
conj(x, out=None, order='K')
425+
426+
Computes conjugate for each element `x_i` for input array `x`.
427+
428+
Args:
429+
x (dpnp.ndarray):
430+
Input array, expected to have numeric data type.
431+
out ({None, dpnp.ndarray}, optional):
432+
Output array to populate. Array must have the correct
433+
shape and the expected data type.
434+
order ("C","F","A","K", optional): memory layout of the new
435+
output array, if parameter `out` is `None`.
436+
Default: "K".
437+
Return:
438+
dpnp.ndarray:
439+
An array containing the element-wise conjugate.
440+
The returned array has the same data type as `x`.
441+
"""
442+
443+
444+
def _call_conj(src, dst, sycl_queue, depends=None):
423445
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
424446

425447
if depends is None:
426448
depends = []
427449

428-
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
429-
# call pybind11 extension for cos() function from OneMKL VM
430-
return vmi._cos(sycl_queue, src, dst, depends)
431-
return ti._cos(src, dst, sycl_queue, depends)
450+
if vmi._mkl_conj_to_call(sycl_queue, src, dst):
451+
# call pybind11 extension for conj() function from OneMKL VM
452+
return vmi._conj(sycl_queue, src, dst, depends)
453+
return ti._conj(src, dst, sycl_queue, depends)
432454

433455

434-
cos_func = UnaryElementwiseFunc(
435-
"cos", ti._cos_result_type, _call_cos, _cos_docstring
456+
conj_func = UnaryElementwiseFunc(
457+
"conj", ti._conj_result_type, _call_conj, _conj_docstring
436458
)
437459

438460

@@ -441,13 +463,42 @@ def dpnp_cos(x, out=None, order="K"):
441463
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
442464
443465
Otherwise fully relies on dpctl.tensor implementation for cos() function.
466+
467+
"""
468+
469+
def _call_cos(src, dst, sycl_queue, depends=None):
470+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
471+
472+
if depends is None:
473+
depends = []
474+
475+
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
476+
# call pybind11 extension for cos() function from OneMKL VM
477+
return vmi._cos(sycl_queue, src, dst, depends)
478+
return ti._cos(src, dst, sycl_queue, depends)
479+
480+
# dpctl.tensor only works with usm_ndarray
481+
x1_usm = dpnp.get_usm_ndarray(x)
482+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
483+
484+
func = UnaryElementwiseFunc(
485+
"cos", ti._cos_result_type, _call_cos, _cos_docstring
486+
)
487+
res_usm = func(x1_usm, out=out_usm, order=order)
488+
return dpnp_array._create_from_usm_ndarray(res_usm)
489+
490+
491+
def dpnp_conj(x, out=None, order="K"):
444492
"""
493+
Invokes conj() function from pybind11 extension of OneMKL VM if possible.
445494
495+
Otherwise fully relies on dpctl.tensor implementation for conj() function.
496+
"""
446497
# dpctl.tensor only works with usm_ndarray
447498
x1_usm = dpnp.get_usm_ndarray(x)
448499
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
449500

450-
res_usm = cos_func(x1_usm, out=out_usm, order=order)
501+
res_usm = conj_func(x1_usm, out=out_usm, order=order)
451502
return dpnp_array._create_from_usm_ndarray(res_usm)
452503

453504

dpnp/dpnp_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ def conj(self):
622622
623623
"""
624624

625-
if not dpnp.issubsctype(self.dtype, dpnp.complex_):
625+
if not dpnp.issubsctype(self.dtype, dpnp.complexfloating):
626626
return self
627627
else:
628628
return dpnp.conjugate(self)
@@ -635,7 +635,7 @@ def conjugate(self):
635635
636636
"""
637637

638-
if not dpnp.issubsctype(self.dtype, dpnp.complex_):
638+
if not dpnp.issubsctype(self.dtype, dpnp.complexfloating):
639639
return self
640640
else:
641641
return dpnp.conjugate(self)

0 commit comments

Comments
 (0)