File tree Expand file tree Collapse file tree 13 files changed +47
-13
lines changed
libtensor/include/kernels/elementwise_functions Expand file tree Collapse file tree 13 files changed +47
-13
lines changed Original file line number Diff line number Diff line change @@ -64,7 +64,7 @@ set_source_files_properties(
6464if (UNIX )
6565 set_source_files_properties (
6666 ${CMAKE_CURRENT_SOURCE_DIR} /libtensor/source /elementwise_functions.cpp
67- PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES" )
67+ PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES;SYCL_EXT_ONEAPI_COMPLEX " )
6868endif ()
6969target_compile_options (${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)
7070target_link_options (${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
Original file line number Diff line number Diff line change 2727#include < cmath>
2828#include < cstddef>
2929#include < cstdint>
30+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031#include < type_traits>
3132
3233#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace acos
4748
4849namespace py = pybind11;
4950namespace td_ns = dpctl::tensor::type_dispatch;
51+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052
5153using dpctl::tensor::type_utils::is_complex;
5254
@@ -114,7 +116,8 @@ template <typename argT, typename resT> struct AcosFunctor
114116 }
115117
116118 /* ordinary cases */
117- return std::acos (in);
119+ return cmplx_ns::acos (
120+ cmplx_ns::complex <realT>(in)); // std::acos(in);
118121 }
119122 else {
120123 static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 2727#include < cmath>
2828#include < cstddef>
2929#include < cstdint>
30+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031#include < type_traits>
3132
3233#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace acosh
4748
4849namespace py = pybind11;
4950namespace td_ns = dpctl::tensor::type_dispatch;
51+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052
5153using dpctl::tensor::type_utils::is_complex;
5254
@@ -118,7 +120,8 @@ template <typename argT, typename resT> struct AcoshFunctor
118120 }
119121 else {
120122 /* ordinary cases */
121- acos_in = std::acos (in);
123+ acos_in = cmplx_ns::acos (
124+ cmplx_ns::complex <realT>(in)); // std::acos(in);
122125 }
123126
124127 /* Now we calculate acosh(z) */
Original file line number Diff line number Diff line change 2727#include < cmath>
2828#include < cstddef>
2929#include < cstdint>
30+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031#include < type_traits>
3132
3233#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace asin
4748
4849namespace py = pybind11;
4950namespace td_ns = dpctl::tensor::type_dispatch;
51+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052
5153using dpctl::tensor::type_utils::is_complex;
5254
@@ -134,7 +136,8 @@ template <typename argT, typename resT> struct AsinFunctor
134136 return resT{asinh_im, asinh_re};
135137 }
136138 /* ordinary cases */
137- return std::asin (in);
139+ return cmplx_ns::asin (
140+ cmplx_ns::complex <realT>(in)); // std::asin(in);
138141 }
139142 else {
140143 static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 2727#include < cmath>
2828#include < cstddef>
2929#include < cstdint>
30+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031#include < type_traits>
3132
3233#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace asinh
4748
4849namespace py = pybind11;
4950namespace td_ns = dpctl::tensor::type_dispatch;
51+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052
5153using dpctl::tensor::type_utils::is_complex;
5254
@@ -115,7 +117,8 @@ template <typename argT, typename resT> struct AsinhFunctor
115117 }
116118
117119 /* ordinary cases */
118- return std::asinh (in);
120+ return cmplx_ns::asinh (
121+ cmplx_ns::complex <realT>(in)); // std::asinh(in);
119122 }
120123 else {
121124 static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 2828#include < complex>
2929#include < cstddef>
3030#include < cstdint>
31+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
3132#include < type_traits>
3233
3334#include " kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace atan
4849
4950namespace py = pybind11;
5051namespace td_ns = dpctl::tensor::type_dispatch;
52+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5153
5254using dpctl::tensor::type_utils::is_complex;
5355
@@ -126,7 +128,8 @@ template <typename argT, typename resT> struct AtanFunctor
126128 return resT{atanh_im, atanh_re};
127129 }
128130 /* ordinary cases */
129- return std::atan (in);
131+ return cmplx_ns::atan (
132+ cmplx_ns::complex <realT>(in)); // std::atan(in);
130133 }
131134 else {
132135 static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 2828#include < complex>
2929#include < cstddef>
3030#include < cstdint>
31+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
3132#include < type_traits>
3233
3334#include " kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace atanh
4849
4950namespace py = pybind11;
5051namespace td_ns = dpctl::tensor::type_dispatch;
52+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5153
5254using dpctl::tensor::type_utils::is_complex;
5355
@@ -119,7 +121,8 @@ template <typename argT, typename resT> struct AtanhFunctor
119121 return resT{res_re, res_im};
120122 }
121123 /* ordinary cases */
122- return std::atanh (in);
124+ return cmplx_ns::atanh (
125+ cmplx_ns::complex <realT>(in)); // std::atanh(in);
123126 }
124127 else {
125128 static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 2727#include < cmath>
2828#include < cstddef>
2929#include < cstdint>
30+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031#include < type_traits>
3132
3233#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace cos
4748
4849namespace py = pybind11;
4950namespace td_ns = dpctl::tensor::type_dispatch;
51+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052
5153using dpctl::tensor::type_utils::is_complex;
5254
@@ -81,7 +83,8 @@ template <typename argT, typename resT> struct CosFunctor
8183 * real and imaginary parts of input are finite.
8284 */
8385 if (in_re_finite && in_im_finite) {
84- return std::cos (in);
86+ return cmplx_ns::cos (
87+ cmplx_ns::complex <realT>(in)); // std::cos(in);
8588 }
8689
8790 /*
Original file line number Diff line number Diff line change 2727#include < cmath>
2828#include < cstddef>
2929#include < cstdint>
30+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031#include < type_traits>
3132
3233#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace cosh
4748
4849namespace py = pybind11;
4950namespace td_ns = dpctl::tensor::type_dispatch;
51+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052
5153using dpctl::tensor::type_utils::is_complex;
5254
@@ -81,7 +83,8 @@ template <typename argT, typename resT> struct CoshFunctor
8183 * real and imaginary parts of input are finite.
8284 */
8385 if (xfinite && yfinite) {
84- return std::cosh (in);
86+ return cmplx_ns::cosh (
87+ cmplx_ns::complex <realT>(in)); // std::cosh(in);
8588 }
8689
8790 /*
Original file line number Diff line number Diff line change 2323// ===---------------------------------------------------------------------===//
2424
2525#pragma once
26- #include < CL/sycl.hpp>
2726#include < cmath>
2827#include < cstddef>
2928#include < cstdint>
29+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30+ #include < sycl/sycl.hpp>
3031#include < type_traits>
3132
3233#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace sin
4748
4849namespace py = pybind11;
4950namespace td_ns = dpctl::tensor::type_dispatch;
51+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052
5153using dpctl::tensor::type_utils::is_complex;
5254
@@ -79,7 +81,8 @@ template <typename argT, typename resT> struct SinFunctor
7981 * real and imaginary parts of input are finite.
8082 */
8183 if (in_re_finite && in_im_finite) {
82- return std::sin (in);
84+ return cmplx_ns::sin (
85+ cmplx_ns::complex <realT>(in)); // std::sin(in);
8386 }
8487
8588 /*
You can’t perform that action at this time.
0 commit comments