Skip to content

Commit b4aa012

Browse files
committed
address comments
1 parent cf7df58 commit b4aa012

File tree

2 files changed

+138
-148
lines changed

2 files changed

+138
-148
lines changed

dpnp/backend/extensions/vm/types_matrix.hpp

+51-51
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,25 @@ struct AddOutputType
6868
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
6969
};
7070

71+
/**
72+
* @brief A factory to define pairs of supported types for which
73+
* MKL VM library provides support in oneapi::mkl::vm::cos<T> function.
74+
*
75+
* @tparam T Type of input vector `a` and of result vector `y`.
76+
*/
77+
template <typename T>
78+
struct CosOutputType
79+
{
80+
using value_type = typename std::disjunction<
81+
dpctl_td_ns::
82+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
83+
dpctl_td_ns::
84+
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
85+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
86+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
87+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
88+
};
89+
7190
/**
7291
* @brief A factory to define pairs of supported types for which
7392
* MKL VM library provides support in oneapi::mkl::vm::div<T> function.
@@ -95,37 +114,31 @@ struct DivOutputType
95114

96115
/**
97116
* @brief A factory to define pairs of supported types for which
98-
* MKL VM library provides support in oneapi::mkl::vm::mul<T> function.
117+
* MKL VM library provides support in oneapi::mkl::vm::ln<T> function.
99118
*
100-
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
119+
* @tparam T Type of input vector `a` and of result vector `y`.
101120
*/
102121
template <typename T>
103-
struct MulOutputType
122+
struct LnOutputType
104123
{
105124
using value_type = typename std::disjunction<
106-
dpctl_td_ns::BinaryTypeMapResultEntry<T,
107-
std::complex<double>,
108-
T,
109-
std::complex<double>,
110-
std::complex<double>>,
111-
dpctl_td_ns::BinaryTypeMapResultEntry<T,
112-
std::complex<float>,
113-
T,
114-
std::complex<float>,
115-
std::complex<float>>,
116-
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
117-
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
125+
dpctl_td_ns::
126+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
127+
dpctl_td_ns::
128+
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
129+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
130+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
118131
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
119132
};
120133

121134
/**
122135
* @brief A factory to define pairs of supported types for which
123-
* MKL VM library provides support in oneapi::mkl::vm::sub<T> function.
136+
* MKL VM library provides support in oneapi::mkl::vm::mul<T> function.
124137
*
125138
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
126139
*/
127140
template <typename T>
128-
struct SubOutputType
141+
struct MulOutputType
129142
{
130143
using value_type = typename std::disjunction<
131144
dpctl_td_ns::BinaryTypeMapResultEntry<T,
@@ -145,12 +158,12 @@ struct SubOutputType
145158

146159
/**
147160
* @brief A factory to define pairs of supported types for which
148-
* MKL VM library provides support in oneapi::mkl::vm::cos<T> function.
161+
* MKL VM library provides support in oneapi::mkl::vm::sin<T> function.
149162
*
150163
* @tparam T Type of input vector `a` and of result vector `y`.
151164
*/
152165
template <typename T>
153-
struct CosOutputType
166+
struct SinOutputType
154167
{
155168
using value_type = typename std::disjunction<
156169
dpctl_td_ns::
@@ -164,31 +177,27 @@ struct CosOutputType
164177

165178
/**
166179
* @brief A factory to define pairs of supported types for which
167-
* MKL VM library provides support in oneapi::mkl::vm::ln<T> function.
180+
* MKL VM library provides support in oneapi::mkl::vm::sqr<T> function.
168181
*
169182
* @tparam T Type of input vector `a` and of result vector `y`.
170183
*/
171184
template <typename T>
172-
struct LnOutputType
185+
struct SqrOutputType
173186
{
174187
using value_type = typename std::disjunction<
175-
dpctl_td_ns::
176-
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
177-
dpctl_td_ns::
178-
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
179188
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
180189
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
181190
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
182191
};
183192

184193
/**
185194
* @brief A factory to define pairs of supported types for which
186-
* MKL VM library provides support in oneapi::mkl::vm::sin<T> function.
195+
* MKL VM library provides support in oneapi::mkl::vm::sqrt<T> function.
187196
*
188197
* @tparam T Type of input vector `a` and of result vector `y`.
189198
*/
190199
template <typename T>
191-
struct SinOutputType
200+
struct SqrtOutputType
192201
{
193202
using value_type = typename std::disjunction<
194203
dpctl_td_ns::
@@ -202,35 +211,26 @@ struct SinOutputType
202211

203212
/**
204213
* @brief A factory to define pairs of supported types for which
205-
* MKL VM library provides support in oneapi::mkl::vm::sqr<T> function.
206-
*
207-
* @tparam T Type of input vector `a` and of result vector `y`.
208-
*/
209-
template <typename T>
210-
struct SqrOutputType
211-
{
212-
using value_type = typename std::disjunction<
213-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
214-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
215-
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
216-
};
217-
218-
/**
219-
* @brief A factory to define pairs of supported types for which
220-
* MKL VM library provides support in oneapi::mkl::vm::sqrt<T> function.
214+
* MKL VM library provides support in oneapi::mkl::vm::sub<T> function.
221215
*
222-
* @tparam T Type of input vector `a` and of result vector `y`.
216+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
223217
*/
224218
template <typename T>
225-
struct SqrtOutputType
219+
struct SubOutputType
226220
{
227221
using value_type = typename std::disjunction<
228-
dpctl_td_ns::
229-
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
230-
dpctl_td_ns::
231-
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
232-
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
233-
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
222+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
223+
std::complex<double>,
224+
T,
225+
std::complex<double>,
226+
std::complex<double>>,
227+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
228+
std::complex<float>,
229+
T,
230+
std::complex<float>,
231+
std::complex<float>>,
232+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
233+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
234234
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
235235
};
236236
} // namespace types

dpnp/backend/extensions/vm/vm_py.cpp

+87-97
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ using vm_ext::binary_impl_fn_ptr_t;
4949
using vm_ext::unary_impl_fn_ptr_t;
5050

5151
static binary_impl_fn_ptr_t add_dispatch_vector[dpctl_td_ns::num_types];
52-
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
53-
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
54-
static binary_impl_fn_ptr_t sub_dispatch_vector[dpctl_td_ns::num_types];
55-
5652
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
53+
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
5754
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
55+
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
5856
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
5957
static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types];
6058
static unary_impl_fn_ptr_t sqrt_dispatch_vector[dpctl_td_ns::num_types];
59+
static binary_impl_fn_ptr_t sub_dispatch_vector[dpctl_td_ns::num_types];
6160

6261
PYBIND11_MODULE(_vm_impl, m)
6362
{
@@ -94,8 +93,33 @@ PYBIND11_MODULE(_vm_impl, m)
9493
py::arg("dst"));
9594
}
9695

97-
using arrayT = dpctl::tensor::usm_ndarray;
98-
using event_vecT = std::vector<sycl::event>;
96+
// UnaryUfunc: ==== Cos(x) ====
97+
{
98+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
99+
vm_ext::CosContigFactory>(
100+
cos_dispatch_vector);
101+
102+
auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
103+
const event_vecT &depends = {}) {
104+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
105+
cos_dispatch_vector);
106+
};
107+
m.def("_cos", cos_pyapi,
108+
"Call `cos` function from OneMKL VM library to compute "
109+
"cosine of vector elements",
110+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
111+
py::arg("depends") = py::list());
112+
113+
auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
114+
arrayT dst) {
115+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
116+
cos_dispatch_vector);
117+
};
118+
m.def("_mkl_cos_to_call", cos_need_to_call_pyapi,
119+
"Check input arguments to answer if `cos` function from "
120+
"OneMKL VM library can be used",
121+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
122+
}
99123

100124
// BinaryUfunc: ==== Div(x1, x2) ====
101125
{
@@ -127,8 +151,33 @@ PYBIND11_MODULE(_vm_impl, m)
127151
py::arg("dst"));
128152
}
129153

130-
using arrayT = dpctl::tensor::usm_ndarray;
131-
using event_vecT = std::vector<sycl::event>;
154+
// UnaryUfunc: ==== Ln(x) ====
155+
{
156+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
157+
vm_ext::LnContigFactory>(
158+
ln_dispatch_vector);
159+
160+
auto ln_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
161+
const event_vecT &depends = {}) {
162+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
163+
ln_dispatch_vector);
164+
};
165+
m.def("_ln", ln_pyapi,
166+
"Call `ln` function from OneMKL VM library to compute "
167+
"natural logarithm of vector elements",
168+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
169+
py::arg("depends") = py::list());
170+
171+
auto ln_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
172+
arrayT dst) {
173+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
174+
ln_dispatch_vector);
175+
};
176+
m.def("_mkl_ln_to_call", ln_need_to_call_pyapi,
177+
"Check input arguments to answer if `ln` function from "
178+
"OneMKL VM library can be used",
179+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
180+
}
132181

133182
// BinaryUfunc: ==== Mul(x1, x2) ====
134183
{
@@ -160,95 +209,6 @@ PYBIND11_MODULE(_vm_impl, m)
160209
py::arg("dst"));
161210
}
162211

163-
using arrayT = dpctl::tensor::usm_ndarray;
164-
using event_vecT = std::vector<sycl::event>;
165-
166-
// BinaryUfunc: ==== Sub(x1, x2) ====
167-
{
168-
vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
169-
vm_ext::SubContigFactory>(
170-
sub_dispatch_vector);
171-
172-
auto sub_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
173-
arrayT dst, const event_vecT &depends = {}) {
174-
return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
175-
sub_dispatch_vector);
176-
};
177-
m.def("_sub", sub_pyapi,
178-
"Call `sub` function from OneMKL VM library to performs element "
179-
"by element subtraction of vector `src1` by vector `src2` "
180-
"to resulting vector `dst`",
181-
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
182-
py::arg("dst"), py::arg("depends") = py::list());
183-
184-
auto sub_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
185-
arrayT src2, arrayT dst) {
186-
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
187-
sub_dispatch_vector);
188-
};
189-
m.def("_mkl_sub_to_call", sub_need_to_call_pyapi,
190-
"Check input arguments to answer if `sub` function from "
191-
"OneMKL VM library can be used",
192-
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
193-
py::arg("dst"));
194-
}
195-
196-
// UnaryUfunc: ==== Cos(x) ====
197-
{
198-
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
199-
vm_ext::CosContigFactory>(
200-
cos_dispatch_vector);
201-
202-
auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
203-
const event_vecT &depends = {}) {
204-
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
205-
cos_dispatch_vector);
206-
};
207-
m.def("_cos", cos_pyapi,
208-
"Call `cos` function from OneMKL VM library to compute "
209-
"cosine of vector elements",
210-
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
211-
py::arg("depends") = py::list());
212-
213-
auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
214-
arrayT dst) {
215-
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
216-
cos_dispatch_vector);
217-
};
218-
m.def("_mkl_cos_to_call", cos_need_to_call_pyapi,
219-
"Check input arguments to answer if `cos` function from "
220-
"OneMKL VM library can be used",
221-
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
222-
}
223-
224-
// UnaryUfunc: ==== Ln(x) ====
225-
{
226-
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
227-
vm_ext::LnContigFactory>(
228-
ln_dispatch_vector);
229-
230-
auto ln_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
231-
const event_vecT &depends = {}) {
232-
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
233-
ln_dispatch_vector);
234-
};
235-
m.def("_ln", ln_pyapi,
236-
"Call `ln` function from OneMKL VM library to compute "
237-
"natural logarithm of vector elements",
238-
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
239-
py::arg("depends") = py::list());
240-
241-
auto ln_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
242-
arrayT dst) {
243-
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
244-
ln_dispatch_vector);
245-
};
246-
m.def("_mkl_ln_to_call", ln_need_to_call_pyapi,
247-
"Check input arguments to answer if `ln` function from "
248-
"OneMKL VM library can be used",
249-
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
250-
}
251-
252212
// UnaryUfunc: ==== Sin(x) ====
253213
{
254214
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
@@ -335,4 +295,34 @@ PYBIND11_MODULE(_vm_impl, m)
335295
"OneMKL VM library can be used",
336296
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
337297
}
298+
299+
// BinaryUfunc: ==== Sub(x1, x2) ====
300+
{
301+
vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
302+
vm_ext::SubContigFactory>(
303+
sub_dispatch_vector);
304+
305+
auto sub_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
306+
arrayT dst, const event_vecT &depends = {}) {
307+
return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
308+
sub_dispatch_vector);
309+
};
310+
m.def("_sub", sub_pyapi,
311+
"Call `sub` function from OneMKL VM library to performs element "
312+
"by element subtraction of vector `src1` by vector `src2` "
313+
"to resulting vector `dst`",
314+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
315+
py::arg("dst"), py::arg("depends") = py::list());
316+
317+
auto sub_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
318+
arrayT src2, arrayT dst) {
319+
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
320+
sub_dispatch_vector);
321+
};
322+
m.def("_mkl_sub_to_call", sub_need_to_call_pyapi,
323+
"Check input arguments to answer if `sub` function from "
324+
"OneMKL VM library can be used",
325+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
326+
py::arg("dst"));
327+
}
338328
}

0 commit comments

Comments
 (0)