@@ -49,15 +49,14 @@ using vm_ext::binary_impl_fn_ptr_t;
49
49
using vm_ext::unary_impl_fn_ptr_t ;
50
50
51
51
static binary_impl_fn_ptr_t add_dispatch_vector[dpctl_td_ns::num_types];
52
- static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
53
- static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
54
- static binary_impl_fn_ptr_t sub_dispatch_vector[dpctl_td_ns::num_types];
55
-
56
52
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
53
+ static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
57
54
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
55
+ static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
58
56
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
59
57
static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types];
60
58
static unary_impl_fn_ptr_t sqrt_dispatch_vector[dpctl_td_ns::num_types];
59
+ static binary_impl_fn_ptr_t sub_dispatch_vector[dpctl_td_ns::num_types];
61
60
62
61
PYBIND11_MODULE (_vm_impl, m)
63
62
{
@@ -94,8 +93,33 @@ PYBIND11_MODULE(_vm_impl, m)
94
93
py::arg (" dst" ));
95
94
}
96
95
97
- using arrayT = dpctl::tensor::usm_ndarray;
98
- using event_vecT = std::vector<sycl::event>;
96
+ // UnaryUfunc: ==== Cos(x) ====
97
+ {
98
+ vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t ,
99
+ vm_ext::CosContigFactory>(
100
+ cos_dispatch_vector);
101
+
102
+ auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
103
+ const event_vecT &depends = {}) {
104
+ return vm_ext::unary_ufunc (exec_q, src, dst, depends,
105
+ cos_dispatch_vector);
106
+ };
107
+ m.def (" _cos" , cos_pyapi,
108
+ " Call `cos` function from OneMKL VM library to compute "
109
+ " cosine of vector elements" ,
110
+ py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ),
111
+ py::arg (" depends" ) = py::list ());
112
+
113
+ auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
114
+ arrayT dst) {
115
+ return vm_ext::need_to_call_unary_ufunc (exec_q, src, dst,
116
+ cos_dispatch_vector);
117
+ };
118
+ m.def (" _mkl_cos_to_call" , cos_need_to_call_pyapi,
119
+ " Check input arguments to answer if `cos` function from "
120
+ " OneMKL VM library can be used" ,
121
+ py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ));
122
+ }
99
123
100
124
// BinaryUfunc: ==== Div(x1, x2) ====
101
125
{
@@ -127,8 +151,33 @@ PYBIND11_MODULE(_vm_impl, m)
127
151
py::arg (" dst" ));
128
152
}
129
153
130
- using arrayT = dpctl::tensor::usm_ndarray;
131
- using event_vecT = std::vector<sycl::event>;
154
+ // UnaryUfunc: ==== Ln(x) ====
155
+ {
156
+ vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t ,
157
+ vm_ext::LnContigFactory>(
158
+ ln_dispatch_vector);
159
+
160
+ auto ln_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
161
+ const event_vecT &depends = {}) {
162
+ return vm_ext::unary_ufunc (exec_q, src, dst, depends,
163
+ ln_dispatch_vector);
164
+ };
165
+ m.def (" _ln" , ln_pyapi,
166
+ " Call `ln` function from OneMKL VM library to compute "
167
+ " natural logarithm of vector elements" ,
168
+ py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ),
169
+ py::arg (" depends" ) = py::list ());
170
+
171
+ auto ln_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
172
+ arrayT dst) {
173
+ return vm_ext::need_to_call_unary_ufunc (exec_q, src, dst,
174
+ ln_dispatch_vector);
175
+ };
176
+ m.def (" _mkl_ln_to_call" , ln_need_to_call_pyapi,
177
+ " Check input arguments to answer if `ln` function from "
178
+ " OneMKL VM library can be used" ,
179
+ py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ));
180
+ }
132
181
133
182
// BinaryUfunc: ==== Mul(x1, x2) ====
134
183
{
@@ -160,95 +209,6 @@ PYBIND11_MODULE(_vm_impl, m)
160
209
py::arg (" dst" ));
161
210
}
162
211
163
- using arrayT = dpctl::tensor::usm_ndarray;
164
- using event_vecT = std::vector<sycl::event>;
165
-
166
- // BinaryUfunc: ==== Sub(x1, x2) ====
167
- {
168
- vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t ,
169
- vm_ext::SubContigFactory>(
170
- sub_dispatch_vector);
171
-
172
- auto sub_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
173
- arrayT dst, const event_vecT &depends = {}) {
174
- return vm_ext::binary_ufunc (exec_q, src1, src2, dst, depends,
175
- sub_dispatch_vector);
176
- };
177
- m.def (" _sub" , sub_pyapi,
178
- " Call `sub` function from OneMKL VM library to performs element "
179
- " by element subtraction of vector `src1` by vector `src2` "
180
- " to resulting vector `dst`" ,
181
- py::arg (" sycl_queue" ), py::arg (" src1" ), py::arg (" src2" ),
182
- py::arg (" dst" ), py::arg (" depends" ) = py::list ());
183
-
184
- auto sub_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
185
- arrayT src2, arrayT dst) {
186
- return vm_ext::need_to_call_binary_ufunc (exec_q, src1, src2, dst,
187
- sub_dispatch_vector);
188
- };
189
- m.def (" _mkl_sub_to_call" , sub_need_to_call_pyapi,
190
- " Check input arguments to answer if `sub` function from "
191
- " OneMKL VM library can be used" ,
192
- py::arg (" sycl_queue" ), py::arg (" src1" ), py::arg (" src2" ),
193
- py::arg (" dst" ));
194
- }
195
-
196
- // UnaryUfunc: ==== Cos(x) ====
197
- {
198
- vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t ,
199
- vm_ext::CosContigFactory>(
200
- cos_dispatch_vector);
201
-
202
- auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
203
- const event_vecT &depends = {}) {
204
- return vm_ext::unary_ufunc (exec_q, src, dst, depends,
205
- cos_dispatch_vector);
206
- };
207
- m.def (" _cos" , cos_pyapi,
208
- " Call `cos` function from OneMKL VM library to compute "
209
- " cosine of vector elements" ,
210
- py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ),
211
- py::arg (" depends" ) = py::list ());
212
-
213
- auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
214
- arrayT dst) {
215
- return vm_ext::need_to_call_unary_ufunc (exec_q, src, dst,
216
- cos_dispatch_vector);
217
- };
218
- m.def (" _mkl_cos_to_call" , cos_need_to_call_pyapi,
219
- " Check input arguments to answer if `cos` function from "
220
- " OneMKL VM library can be used" ,
221
- py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ));
222
- }
223
-
224
- // UnaryUfunc: ==== Ln(x) ====
225
- {
226
- vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t ,
227
- vm_ext::LnContigFactory>(
228
- ln_dispatch_vector);
229
-
230
- auto ln_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
231
- const event_vecT &depends = {}) {
232
- return vm_ext::unary_ufunc (exec_q, src, dst, depends,
233
- ln_dispatch_vector);
234
- };
235
- m.def (" _ln" , ln_pyapi,
236
- " Call `ln` function from OneMKL VM library to compute "
237
- " natural logarithm of vector elements" ,
238
- py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ),
239
- py::arg (" depends" ) = py::list ());
240
-
241
- auto ln_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
242
- arrayT dst) {
243
- return vm_ext::need_to_call_unary_ufunc (exec_q, src, dst,
244
- ln_dispatch_vector);
245
- };
246
- m.def (" _mkl_ln_to_call" , ln_need_to_call_pyapi,
247
- " Check input arguments to answer if `ln` function from "
248
- " OneMKL VM library can be used" ,
249
- py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ));
250
- }
251
-
252
212
// UnaryUfunc: ==== Sin(x) ====
253
213
{
254
214
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t ,
@@ -335,4 +295,34 @@ PYBIND11_MODULE(_vm_impl, m)
335
295
" OneMKL VM library can be used" ,
336
296
py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ));
337
297
}
298
+
299
+ // BinaryUfunc: ==== Sub(x1, x2) ====
300
+ {
301
+ vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t ,
302
+ vm_ext::SubContigFactory>(
303
+ sub_dispatch_vector);
304
+
305
+ auto sub_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
306
+ arrayT dst, const event_vecT &depends = {}) {
307
+ return vm_ext::binary_ufunc (exec_q, src1, src2, dst, depends,
308
+ sub_dispatch_vector);
309
+ };
310
+ m.def (" _sub" , sub_pyapi,
311
+ " Call `sub` function from OneMKL VM library to performs element "
312
+ " by element subtraction of vector `src1` by vector `src2` "
313
+ " to resulting vector `dst`" ,
314
+ py::arg (" sycl_queue" ), py::arg (" src1" ), py::arg (" src2" ),
315
+ py::arg (" dst" ), py::arg (" depends" ) = py::list ());
316
+
317
+ auto sub_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
318
+ arrayT src2, arrayT dst) {
319
+ return vm_ext::need_to_call_binary_ufunc (exec_q, src1, src2, dst,
320
+ sub_dispatch_vector);
321
+ };
322
+ m.def (" _mkl_sub_to_call" , sub_need_to_call_pyapi,
323
+ " Check input arguments to answer if `sub` function from "
324
+ " OneMKL VM library can be used" ,
325
+ py::arg (" sycl_queue" ), py::arg (" src1" ), py::arg (" src2" ),
326
+ py::arg (" dst" ));
327
+ }
338
328
}
0 commit comments