|
45 | 45 | #include "cosh.hpp" |
46 | 46 | #include "div.hpp" |
47 | 47 | #include "floor.hpp" |
| 48 | +#include "hypot.hpp" |
48 | 49 | #include "ln.hpp" |
49 | 50 | #include "mul.hpp" |
50 | 51 | #include "pow.hpp" |
@@ -74,11 +75,12 @@ static unary_impl_fn_ptr_t atan_dispatch_vector[dpctl_td_ns::num_types]; |
74 | 75 | static binary_impl_fn_ptr_t atan2_dispatch_vector[dpctl_td_ns::num_types]; |
75 | 76 | static unary_impl_fn_ptr_t atanh_dispatch_vector[dpctl_td_ns::num_types]; |
76 | 77 | static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types]; |
| 78 | +static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types]; |
77 | 79 | static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types]; |
78 | 80 | static unary_impl_fn_ptr_t cosh_dispatch_vector[dpctl_td_ns::num_types]; |
79 | 81 | static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types]; |
80 | 82 | static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types]; |
81 | | -static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types]; |
| 83 | +static binary_impl_fn_ptr_t hypot_dispatch_vector[dpctl_td_ns::num_types]; |
82 | 84 | static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types]; |
83 | 85 | static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types]; |
84 | 86 | static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types]; |
@@ -494,6 +496,35 @@ PYBIND11_MODULE(_vm_impl, m) |
494 | 496 | py::arg("sycl_queue"), py::arg("src"), py::arg("dst")); |
495 | 497 | } |
496 | 498 |
|
| 499 | + // BinaryUfunc: ==== Hypot(x1, x2) ==== |
| 500 | + { |
| 501 | + vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t, |
| 502 | + vm_ext::HypotContigFactory>( |
| 503 | + hypot_dispatch_vector); |
| 504 | + |
| 505 | + auto hypot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, |
| 506 | + arrayT dst, const event_vecT &depends = {}) { |
| 507 | + return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, |
| 508 | + hypot_dispatch_vector); |
| 509 | + }; |
| 510 | + m.def("_hypot", hypot_pyapi, |
| 511 | + "Call `hypot` function from OneMKL VM library to compute element " |
| 512 | + "by element hypotenuse of `x`", |
| 513 | + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), |
| 514 | + py::arg("dst"), py::arg("depends") = py::list()); |
| 515 | + |
| 516 | + auto hypot_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, |
| 517 | + arrayT src2, arrayT dst) { |
| 518 | + return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, |
| 519 | + hypot_dispatch_vector); |
| 520 | + }; |
| 521 | + m.def("_mkl_hypot_to_call", hypot_need_to_call_pyapi, |
| 522 | + "Check input arguments to answer if `hypot` function from " |
| 523 | + "OneMKL VM library can be used", |
| 524 | + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), |
| 525 | + py::arg("dst")); |
| 526 | + } |
| 527 | + |
497 | 528 | // UnaryUfunc: ==== Ln(x) ==== |
498 | 529 | { |
499 | 530 | vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t, |
|
0 commit comments