Skip to content

Commit 189d2d5

Browse files
Merge branch 'master' into implement-clip
2 parents 2555cc6 + 1d57614 commit 189d2d5

File tree

12 files changed

+1609
-5
lines changed

12 files changed

+1609
-5
lines changed

dpctl/tensor/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,16 @@
111111
bitwise_or,
112112
bitwise_right_shift,
113113
bitwise_xor,
114+
cbrt,
114115
ceil,
115116
conj,
117+
copysign,
116118
cos,
117119
cosh,
118120
divide,
119121
equal,
120122
exp,
123+
exp2,
121124
expm1,
122125
floor,
123126
floor_divide,
@@ -150,6 +153,7 @@
150153
real,
151154
remainder,
152155
round,
156+
rsqrt,
153157
sign,
154158
signbit,
155159
sin,
@@ -316,4 +320,8 @@
316320
"argmin",
317321
"prod",
318322
"clip",
323+
"cbrt",
324+
"exp2",
325+
"copysign",
326+
"rsqrt",
319327
]

dpctl/tensor/_elementwise_funcs.py

+113
Original file line numberDiff line numberDiff line change
@@ -1761,3 +1761,116 @@
17611761
hypot = BinaryElementwiseFunc(
17621762
"hypot", ti._hypot_result_type, ti._hypot, _hypot_docstring_
17631763
)
1764+
1765+
1766+
# U37: ==== CBRT (x)
1767+
_cbrt_docstring_ = """
1768+
cbrt(x, out=None, order='K')
1769+
1770+
Computes the cube-root for each element `x_i` of input array `x`.
1771+
1772+
Args:
1773+
x (usm_ndarray):
1774+
Input array, expected to have a real floating-point data type.
1775+
out ({None, usm_ndarray}, optional):
1776+
Output array to populate.
1777+
Array must have the correct shape and the expected data type.
1778+
order ("C","F","A","K", optional):
1779+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
1780+
Default: "K".
1781+
Returns:
1782+
usm_ndarray:
1783+
An array containing the element-wise cube-root.
1784+
The data type of the returned array is determined by
1785+
the Type Promotion Rules.
1786+
"""
1787+
1788+
cbrt = UnaryElementwiseFunc(
1789+
"cbrt", ti._cbrt_result_type, ti._cbrt, _cbrt_docstring_
1790+
)
1791+
1792+
1793+
# U38: ==== EXP2 (x)
1794+
_exp2_docstring_ = """
1795+
exp2(x, out=None, order='K')
1796+
1797+
Computes the base-2 exponential for each element `x_i` for input array `x`.
1798+
1799+
Args:
1800+
x (usm_ndarray):
1801+
Input array, expected to have a floating-point data type.
1802+
out ({None, usm_ndarray}, optional):
1803+
Output array to populate.
1804+
Array must have the correct shape and the expected data type.
1805+
order ("C","F","A","K", optional):
1806+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
1807+
Default: "K".
1808+
Returns:
1809+
usm_ndarray:
1810+
An array containing the element-wise base-2 exponentials.
1811+
The data type of the returned array is determined by
1812+
the Type Promotion Rules.
1813+
"""
1814+
1815+
exp2 = UnaryElementwiseFunc(
1816+
"exp2", ti._exp2_result_type, ti._exp2, _exp2_docstring_
1817+
)
1818+
1819+
1820+
# B25: ==== COPYSIGN (x1, x2)
1821+
_copysign_docstring_ = """
1822+
copysign(x1, x2, out=None, order='K')
1823+
1824+
Composes a floating-point value with the magnitude of `x1_i` and the sign of
1825+
`x2_i` for each element of input arrays `x1` and `x2`.
1826+
1827+
Args:
1828+
x1 (usm_ndarray):
1829+
First input array, expected to have a real floating-point data type.
1830+
x2 (usm_ndarray):
1831+
Second input array, also expected to have a real floating-point data
1832+
type.
1833+
out ({None, usm_ndarray}, optional):
1834+
Output array to populate.
1835+
Array must have the correct shape and the expected data type.
1836+
order ("C","F","A","K", optional):
1837+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
1838+
Default: "K".
1839+
Returns:
1840+
usm_ndarray:
1841+
An array containing the element-wise results. The data type
1842+
of the returned array is determined by the Type Promotion Rules.
1843+
"""
1844+
copysign = BinaryElementwiseFunc(
1845+
"copysign",
1846+
ti._copysign_result_type,
1847+
ti._copysign,
1848+
_copysign_docstring_,
1849+
)
1850+
1851+
1852+
# U39: ==== RSQRT (x)
1853+
_rsqrt_docstring_ = """
1854+
rsqrt(x, out=None, order='K')
1855+
1856+
Computes the reciprocal square-root for each element `x_i` for input array `x`.
1857+
1858+
Args:
1859+
x (usm_ndarray):
1860+
Input array, expected to have a real floating-point data type.
1861+
out ({None, usm_ndarray}, optional):
1862+
Output array to populate.
1863+
Array must have the correct shape and the expected data type.
1864+
order ("C","F","A","K", optional):
1865+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
1866+
Default: "K".
1867+
Returns:
1868+
usm_ndarray:
1869+
An array containing the element-wise reciprocal square-root.
1870+
The data type of the returned array is determined by
1871+
the Type Promotion Rules.
1872+
"""
1873+
1874+
rsqrt = UnaryElementwiseFunc(
1875+
"rsqrt", ti._rsqrt_result_type, ti._rsqrt, _rsqrt_docstring_
1876+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
//=== cbrt.hpp - Unary function CBRT ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of CBRT(x)
23+
/// function that computes a cube root.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cmath>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "kernels/elementwise_functions/common.hpp"
34+
35+
#include "utils/offset_utils.hpp"
36+
#include "utils/type_dispatch.hpp"
37+
#include "utils/type_utils.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace cbrt
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
52+
template <typename argT, typename resT> struct CbrtFunctor
53+
{
54+
55+
// is function constant for given argT
56+
using is_constant = typename std::false_type;
57+
// constant value, if constant
58+
// constexpr resT constant_value = resT{};
59+
// is function defined for sycl::vec
60+
using supports_vec = typename std::false_type;
61+
// do both argTy and resTy support sub-group store/load operations
62+
using supports_sg_loadstore = typename std::true_type;
63+
64+
resT operator()(const argT &in) const
65+
{
66+
return sycl::cbrt(in);
67+
}
68+
};
69+
70+
template <typename argTy,
71+
typename resTy = argTy,
72+
unsigned int vec_sz = 4,
73+
unsigned int n_vecs = 2>
74+
using CbrtContigFunctor = elementwise_common::
75+
UnaryContigFunctor<argTy, resTy, CbrtFunctor<argTy, resTy>, vec_sz, n_vecs>;
76+
77+
template <typename argTy, typename resTy, typename IndexerT>
78+
using CbrtStridedFunctor = elementwise_common::
79+
UnaryStridedFunctor<argTy, resTy, IndexerT, CbrtFunctor<argTy, resTy>>;
80+
81+
template <typename T> struct CbrtOutputType
82+
{
83+
using value_type = typename std::disjunction< // disjunction is C++17
84+
// feature, supported by DPC++
85+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
86+
td_ns::TypeMapResultEntry<T, float, float>,
87+
td_ns::TypeMapResultEntry<T, double, double>,
88+
td_ns::DefaultResultEntry<void>>::result_type;
89+
};
90+
91+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
92+
class cbrt_contig_kernel;
93+
94+
template <typename argTy>
95+
sycl::event cbrt_contig_impl(sycl::queue &exec_q,
96+
size_t nelems,
97+
const char *arg_p,
98+
char *res_p,
99+
const std::vector<sycl::event> &depends = {})
100+
{
101+
return elementwise_common::unary_contig_impl<
102+
argTy, CbrtOutputType, CbrtContigFunctor, cbrt_contig_kernel>(
103+
exec_q, nelems, arg_p, res_p, depends);
104+
}
105+
106+
template <typename fnT, typename T> struct CbrtContigFactory
107+
{
108+
fnT get()
109+
{
110+
if constexpr (std::is_same_v<typename CbrtOutputType<T>::value_type,
111+
void>) {
112+
fnT fn = nullptr;
113+
return fn;
114+
}
115+
else {
116+
fnT fn = cbrt_contig_impl<T>;
117+
return fn;
118+
}
119+
}
120+
};
121+
122+
template <typename fnT, typename T> struct CbrtTypeMapFactory
123+
{
124+
/*! @brief get typeid for output type of sycl::cbrt(T x) */
125+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
126+
{
127+
using rT = typename CbrtOutputType<T>::value_type;
128+
return td_ns::GetTypeid<rT>{}.get();
129+
}
130+
};
131+
132+
template <typename T1, typename T2, typename T3> class cbrt_strided_kernel;
133+
134+
template <typename argTy>
135+
sycl::event
136+
cbrt_strided_impl(sycl::queue &exec_q,
137+
size_t nelems,
138+
int nd,
139+
const py::ssize_t *shape_and_strides,
140+
const char *arg_p,
141+
py::ssize_t arg_offset,
142+
char *res_p,
143+
py::ssize_t res_offset,
144+
const std::vector<sycl::event> &depends,
145+
const std::vector<sycl::event> &additional_depends)
146+
{
147+
return elementwise_common::unary_strided_impl<
148+
argTy, CbrtOutputType, CbrtStridedFunctor, cbrt_strided_kernel>(
149+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
150+
res_offset, depends, additional_depends);
151+
}
152+
153+
template <typename fnT, typename T> struct CbrtStridedFactory
154+
{
155+
fnT get()
156+
{
157+
if constexpr (std::is_same_v<typename CbrtOutputType<T>::value_type,
158+
void>) {
159+
fnT fn = nullptr;
160+
return fn;
161+
}
162+
else {
163+
fnT fn = cbrt_strided_impl<T>;
164+
return fn;
165+
}
166+
}
167+
};
168+
169+
} // namespace cbrt
170+
} // namespace kernels
171+
} // namespace tensor
172+
} // namespace dpctl

0 commit comments

Comments
 (0)