Skip to content

Commit 36f26aa

Browse files
authored
Merge pull request #1522 from IntelPython/create_Unary_BinaryElementwisefunc_during_module_import
create_Unary_BinaryElementwisefunc_during_module_import
2 parents 18da200 + 24116e1 commit 36f26aa

File tree

8 files changed

+814
-278
lines changed

8 files changed

+814
-278
lines changed

dpnp/backend/extensions/vm/add.hpp

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event add_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
const char *in_b,
46+
char *out_y,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
type_utils::validate_type_for_device<T>(exec_q);
50+
51+
const T *a = reinterpret_cast<const T *>(in_a);
52+
const T *b = reinterpret_cast<const T *>(in_b);
53+
T *y = reinterpret_cast<T *>(out_y);
54+
55+
return mkl_vm::add(exec_q,
56+
n, // number of elements to be calculated
57+
a, // pointer `a` containing 1st input vector of size n
58+
b, // pointer `b` containing 2nd input vector of size n
59+
y, // pointer `y` to the output vector of size n
60+
depends);
61+
}
62+
63+
template <typename fnT, typename T>
64+
struct AddContigFactory
65+
{
66+
fnT get()
67+
{
68+
if constexpr (std::is_same_v<
69+
typename types::AddOutputType<T>::value_type, void>)
70+
{
71+
return nullptr;
72+
}
73+
else {
74+
return add_contig_impl<T>;
75+
}
76+
}
77+
};
78+
} // namespace vm
79+
} // namespace ext
80+
} // namespace backend
81+
} // namespace dpnp

dpnp/backend/extensions/vm/mul.hpp

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event mul_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
const char *in_b,
46+
char *out_y,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
type_utils::validate_type_for_device<T>(exec_q);
50+
51+
const T *a = reinterpret_cast<const T *>(in_a);
52+
const T *b = reinterpret_cast<const T *>(in_b);
53+
T *y = reinterpret_cast<T *>(out_y);
54+
55+
return mkl_vm::mul(exec_q,
56+
n, // number of elements to be calculated
57+
a, // pointer `a` containing 1st input vector of size n
58+
b, // pointer `b` containing 2nd input vector of size n
59+
y, // pointer `y` to the output vector of size n
60+
depends);
61+
}
62+
63+
template <typename fnT, typename T>
64+
struct MulContigFactory
65+
{
66+
fnT get()
67+
{
68+
if constexpr (std::is_same_v<
69+
typename types::MulOutputType<T>::value_type, void>)
70+
{
71+
return nullptr;
72+
}
73+
else {
74+
return mul_contig_impl<T>;
75+
}
76+
}
77+
};
78+
} // namespace vm
79+
} // namespace ext
80+
} // namespace backend
81+
} // namespace dpnp

dpnp/backend/extensions/vm/sub.hpp

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event sub_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
const char *in_b,
46+
char *out_y,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
type_utils::validate_type_for_device<T>(exec_q);
50+
51+
const T *a = reinterpret_cast<const T *>(in_a);
52+
const T *b = reinterpret_cast<const T *>(in_b);
53+
T *y = reinterpret_cast<T *>(out_y);
54+
55+
return mkl_vm::sub(exec_q,
56+
n, // number of elements to be calculated
57+
a, // pointer `a` containing 1st input vector of size n
58+
b, // pointer `b` containing 2nd input vector of size n
59+
y, // pointer `y` to the output vector of size n
60+
depends);
61+
}
62+
63+
template <typename fnT, typename T>
64+
struct SubContigFactory
65+
{
66+
fnT get()
67+
{
68+
if constexpr (std::is_same_v<
69+
typename types::SubOutputType<T>::value_type, void>)
70+
{
71+
return nullptr;
72+
}
73+
else {
74+
return sub_contig_impl<T>;
75+
}
76+
}
77+
};
78+
} // namespace vm
79+
} // namespace ext
80+
} // namespace backend
81+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

+77-2
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ namespace types
4545
{
4646
/**
4747
* @brief A factory to define pairs of supported types for which
48-
* MKL VM library provides support in oneapi::mkl::vm::div<T> function.
48+
* MKL VM library provides support in oneapi::mkl::vm::add<T> function.
4949
*
5050
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
5151
*/
5252
template <typename T>
53-
struct DivOutputType
53+
struct AddOutputType
5454
{
5555
using value_type = typename std::disjunction<
5656
dpctl_td_ns::BinaryTypeMapResultEntry<T,
@@ -102,6 +102,31 @@ struct CosOutputType
102102
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
103103
};
104104

105+
/**
106+
* @brief A factory to define pairs of supported types for which
107+
* MKL VM library provides support in oneapi::mkl::vm::div<T> function.
108+
*
109+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
110+
*/
111+
template <typename T>
112+
struct DivOutputType
113+
{
114+
using value_type = typename std::disjunction<
115+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
116+
std::complex<double>,
117+
T,
118+
std::complex<double>,
119+
std::complex<double>>,
120+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
121+
std::complex<float>,
122+
T,
123+
std::complex<float>,
124+
std::complex<float>>,
125+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
126+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
127+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
128+
};
129+
105130
/**
106131
* @brief A factory to define pairs of supported types for which
107132
* MKL VM library provides support in oneapi::mkl::vm::floor<T> function.
@@ -136,6 +161,31 @@ struct LnOutputType
136161
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
137162
};
138163

164+
/**
165+
* @brief A factory to define pairs of supported types for which
166+
* MKL VM library provides support in oneapi::mkl::vm::mul<T> function.
167+
*
168+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
169+
*/
170+
template <typename T>
171+
struct MulOutputType
172+
{
173+
using value_type = typename std::disjunction<
174+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
175+
std::complex<double>,
176+
T,
177+
std::complex<double>,
178+
std::complex<double>>,
179+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
180+
std::complex<float>,
181+
T,
182+
std::complex<float>,
183+
std::complex<float>>,
184+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
185+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
186+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
187+
};
188+
139189
/**
140190
* @brief A factory to define pairs of supported types for which
141191
* MKL VM library provides support in oneapi::mkl::vm::sin<T> function.
@@ -189,6 +239,31 @@ struct SqrtOutputType
189239
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
190240
};
191241

242+
/**
243+
* @brief A factory to define pairs of supported types for which
244+
* MKL VM library provides support in oneapi::mkl::vm::sub<T> function.
245+
*
246+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
247+
*/
248+
template <typename T>
249+
struct SubOutputType
250+
{
251+
using value_type = typename std::disjunction<
252+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
253+
std::complex<double>,
254+
T,
255+
std::complex<double>,
256+
std::complex<double>>,
257+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
258+
std::complex<float>,
259+
T,
260+
std::complex<float>,
261+
std::complex<float>>,
262+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
263+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
264+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
265+
};
266+
192267
/**
193268
* @brief A factory to define pairs of supported types for which
194269
* MKL VM library provides support in oneapi::mkl::vm::trunc<T> function.

0 commit comments

Comments
 (0)