Skip to content

Commit 27c7623

Browse files
Use oneapi extension for complexes for remaining elementwise functions
Used functions from sycl::ext::oneapi::experimental context to implement evaluation on data of complex type.
1 parent e84989b commit 27c7623

21 files changed

+105
-33
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "utils/math_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "utils/math_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
30+
#include <sycl/sycl.hpp>
3031
#include <type_traits>
3132

3233
#include "utils/offset_utils.hpp"
@@ -49,6 +50,7 @@ namespace multiply
4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
5152
namespace tu_ns = dpctl::tensor::type_utils;
53+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5254

5355
template <typename argT1, typename argT2, typename resT> struct MultiplyFunctor
5456
{
@@ -60,7 +62,18 @@ template <typename argT1, typename argT2, typename resT> struct MultiplyFunctor
6062

6163
resT operator()(const argT1 &in1, const argT2 &in2) const
6264
{
63-
return in1 * in2;
65+
if constexpr (tu_ns::is_complex<argT1>::value &&
66+
tu_ns::is_complex<argT2>::value)
67+
{
68+
using realT1 = typename argT1::value_type;
69+
using realT2 = typename argT2::value_type;
70+
71+
return exprm_ns::complex<realT1>(in1) *
72+
exprm_ns::complex<realT2>(in2);
73+
}
74+
else {
75+
return in1 * in2;
76+
}
6477
}
6578

6679
template <int vec_sz>

dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "utils/offset_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
3029
#include <limits>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#include <sycl/sycl.hpp>
3132
#include <type_traits>
3233

3334
#include "utils/offset_utils.hpp"
@@ -49,6 +50,7 @@ namespace pow
4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
5152
namespace tu_ns = dpctl::tensor::type_utils;
53+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5254

5355
template <typename argT1, typename argT2, typename resT> struct PowFunctor
5456
{
@@ -83,6 +85,15 @@ template <typename argT1, typename argT2, typename resT> struct PowFunctor
8385
}
8486
return res;
8587
}
88+
else if constexpr (tu_ns::is_complex<argT1>::value &&
89+
tu_ns::is_complex<argT2>::value)
90+
{
91+
using realT1 = typename argT1::value_type;
92+
using realT2 = typename argT2::value_type;
93+
94+
return exprm_ns::pow(exprm_ns::complex<realT1>(in1),
95+
exprm_ns::complex<realT2>(in2));
96+
}
8697
else {
8798
return std::pow(in1, in2);
8899
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <complex>
3029
#include <cstddef>
3130
#include <cstdint>
3231
#include <limits>
32+
#include <sycl/sycl.hpp>
3333
#include <type_traits>
3434

3535
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <complex>
3029
#include <cstddef>
3130
#include <cstdint>
31+
#include <sycl/sycl.hpp>
3232
#include <type_traits>
3333

3434
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
//===---------------------------------------------------------------------===//
2626

2727
#pragma once
28-
#include <CL/sycl.hpp>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "utils/offset_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
//===---------------------------------------------------------------------===//
2424

2525
#pragma once
26-
#include <CL/sycl.hpp>
2726
#include <cmath>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp

+14-9
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@
2323
//===---------------------------------------------------------------------===//
2424

2525
#pragma once
26-
#include <CL/sycl.hpp>
2726
#include <cmath>
2827
#include <cstddef>
2928
#include <cstdint>
3029
#include <limits>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#include <sycl/sycl.hpp>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace sign
4849

4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5153

5254
using dpctl::tensor::type_utils::is_complex;
5355
using dpctl::tensor::type_utils::vec_cast;
@@ -61,38 +63,41 @@ template <typename argT, typename resT> struct SignFunctor
6163
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6264
using supports_sg_loadstore = std::false_type;
6365

64-
resT operator()(const argT &x) const
66+
resT operator()(const argT &in) const
6567
{
6668
if constexpr (std::is_integral_v<argT>) {
6769
if constexpr (std::is_unsigned_v<argT>) {
68-
return resT(0 < x);
70+
return resT(0 < in);
6971
}
7072
else {
71-
return sign<argT>(x);
73+
return sign_impl<argT>(in);
7274
}
7375
}
7476
else {
7577
if constexpr (is_complex<argT>::value) {
76-
if (x == argT(0)) {
78+
using realT = typename argT::value_type;
79+
80+
if (in == argT(0)) {
7781
return resT(0);
7882
}
7983
else {
80-
return (x / std::abs(x));
84+
auto z = exprm_ns::complex<realT>(in);
85+
return (z / exprm_ns::abs(z));
8186
}
8287
}
8388
else {
84-
if (std::isnan(x)) {
89+
if (std::isnan(in)) {
8590
return std::numeric_limits<resT>::quiet_NaN();
8691
}
8792
else {
88-
return sign<argT>(x);
93+
return sign_impl<argT>(in);
8994
}
9095
}
9196
}
9297
}
9398

9499
private:
95-
template <typename T> T sign(const T &v) const
100+
template <typename T> T sign_impl(const T &v) const
96101
{
97102
return (T(0) < v) - (v < T(0));
98103
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "utils/offset_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
//===---------------------------------------------------------------------===//
2424

2525
#pragma once
26-
#include <CL/sycl.hpp>
2726
#include <cmath>
2827
#include <cstddef>
2928
#include <cstdint>
3029
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
30+
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

3333
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <complex>
3029
#include <cstddef>
3130
#include <cstdint>
3231
#include <limits>
32+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
33+
#include <sycl/sycl.hpp>
3334
#include <type_traits>
3435

3536
#include "kernels/elementwise_functions/common.hpp"
@@ -50,6 +51,7 @@ namespace sqrt
5051

5152
namespace py = pybind11;
5253
namespace td_ns = dpctl::tensor::type_dispatch;
54+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5355

5456
using dpctl::tensor::type_utils::is_complex;
5557

@@ -74,7 +76,10 @@ template <typename argT, typename resT> struct SqrtFunctor
7476
// #else
7577
// return std::sqrt(in);
7678
// #endif
77-
return csqrt(in);
79+
using realT = typename argT::value_type;
80+
81+
// return csqrt(in);
82+
return exprm_ns::sqrt(exprm_ns::complex<realT>(in));
7883
}
7984
else {
8085
return std::sqrt(in);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <cstddef>
3029
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#include <sycl/sycl.hpp>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace square
4849

4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5153

5254
using dpctl::tensor::type_utils::is_complex;
5355
using dpctl::tensor::type_utils::vec_cast;
@@ -68,7 +70,16 @@ template <typename argT, typename resT> struct SquareFunctor
6870

6971
resT operator()(const argT &in) const
7072
{
71-
return in * in;
73+
if constexpr (is_complex<argT>::value) {
74+
using realT = typename argT::value_type;
75+
76+
auto z = exprm_ns::complex<realT>(in);
77+
78+
return z * z;
79+
}
80+
else {
81+
return in * in;
82+
}
7283
}
7384

7485
template <int vec_sz>

dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

3232
#include "utils/offset_utils.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
//===---------------------------------------------------------------------===//
2424

2525
#pragma once
26-
#include <CL/sycl.hpp>
2726
#include <cmath>
2827
#include <complex>
2928
#include <cstddef>
3029
#include <cstdint>
3130
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#include <sycl/sycl.hpp>
3232
#include <type_traits>
3333

3434
#include "kernels/elementwise_functions/common.hpp"

dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
27-
#include <CL/sycl.hpp>
2827
#include <cmath>
2928
#include <complex>
3029
#include <cstddef>
3130
#include <cstdint>
3231
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
32+
#include <sycl/sycl.hpp>
3333
#include <type_traits>
3434

3535
#include "kernels/elementwise_functions/common.hpp"

0 commit comments

Comments
 (0)