@@ -51,7 +51,7 @@ namespace td_ns = dpctl::tensor::type_dispatch;
51
51
52
52
using namespace dpctl ::tensor::offset_utils;
53
53
54
- template <typename T> T dpt_clamp (const T &x, const T &min, const T &max)
54
+ template <typename T> T clip (const T &x, const T &min, const T &max)
55
55
{
56
56
using dpctl::tensor::type_utils::is_complex;
57
57
if constexpr (is_complex<T>::value) {
@@ -105,8 +105,7 @@ template <typename T, int vec_sz = 4, int n_vecs = 2> class ClipContigFunctor
105
105
offset < std::min (nelems, base + sgSize * (n_vecs * vec_sz));
106
106
offset += sgSize)
107
107
{
108
- dst_p[offset] =
109
- dpt_clamp (x_p[offset], min_p[offset], max_p[offset]);
108
+ dst_p[offset] = clip (x_p[offset], min_p[offset], max_p[offset]);
110
109
}
111
110
}
112
111
else {
@@ -142,32 +141,18 @@ template <typename T, int vec_sz = 4, int n_vecs = 2> class ClipContigFunctor
142
141
x_vec = sg.load <vec_sz>(x_multi_ptr);
143
142
min_vec = sg.load <vec_sz>(min_multi_ptr);
144
143
max_vec = sg.load <vec_sz>(max_multi_ptr);
145
- if constexpr (std::is_floating_point_v<T> ||
146
- std::is_same_v<T, bool >) {
147
144
#pragma unroll
148
- for (std::uint8_t vec_id = 0 ; vec_id < vec_sz; ++vec_id)
149
- {
150
- dst_vec[vec_id] =
151
- dpt_clamp (x_vec[vec_id], min_vec[vec_id],
152
- max_vec[vec_id]);
153
- }
154
- }
155
- else {
156
- dst_vec = sycl::clamp (x_vec, min_vec, max_vec);
145
+ for (std::uint8_t vec_id = 0 ; vec_id < vec_sz; ++vec_id) {
146
+ dst_vec[vec_id] = clip (x_vec[vec_id], min_vec[vec_id],
147
+ max_vec[vec_id]);
157
148
}
158
149
sg.store <vec_sz>(dst_multi_ptr, dst_vec);
159
150
}
160
151
}
161
152
else {
162
153
for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems;
163
154
k += sgSize) {
164
- if constexpr (std::is_floating_point_v<T> ||
165
- std::is_same_v<T, bool >) {
166
- dst_p[k] = dpt_clamp (x_p[k], min_p[k], max_p[k]);
167
- }
168
- else {
169
- dst_p[k] = sycl::clamp (x_p[k], min_p[k], max_p[k]);
170
- }
155
+ dst_p[k] = clip (x_p[k], min_p[k], max_p[k]);
171
156
}
172
157
}
173
158
}
@@ -243,18 +228,9 @@ template <typename T, typename IndexerT> class ClipStridedFunctor
243
228
{
244
229
size_t gid = id[0 ];
245
230
auto offsets = indexer (static_cast <py::ssize_t >(gid));
246
- if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool >) {
247
- dst_p[offsets.get_fourth_offset ()] =
248
- sycl::clamp (x_p[offsets.get_first_offset ()],
249
- min_p[offsets.get_second_offset ()],
250
- max_p[offsets.get_third_offset ()]);
251
- }
252
- else {
253
- dst_p[offsets.get_fourth_offset ()] =
254
- dpt_clamp (x_p[offsets.get_first_offset ()],
255
- min_p[offsets.get_second_offset ()],
256
- max_p[offsets.get_third_offset ()]);
257
- }
231
+ dst_p[offsets.get_fourth_offset ()] = clip (
232
+ x_p[offsets.get_first_offset ()], min_p[offsets.get_second_offset ()],
233
+ max_p[offsets.get_third_offset ()]);
258
234
}
259
235
};
260
236
0 commit comments