Skip to content

Commit 0df7721

Browse files
committed
Clip now consistently yields max where max < min
sycl::clamp would yield max or min depending on the platform. A test has been added for this behavior.
1 parent 4eb83aa commit 0df7721

File tree

2 files changed

+17
-33
lines changed

2 files changed

+17
-33
lines changed

dpctl/tensor/libtensor/include/kernels/clip.hpp

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace td_ns = dpctl::tensor::type_dispatch;
5151

5252
using namespace dpctl::tensor::offset_utils;
5353

54-
template <typename T> T dpt_clamp(const T &x, const T &min, const T &max)
54+
template <typename T> T clip(const T &x, const T &min, const T &max)
5555
{
5656
using dpctl::tensor::type_utils::is_complex;
5757
if constexpr (is_complex<T>::value) {
@@ -105,8 +105,7 @@ template <typename T, int vec_sz = 4, int n_vecs = 2> class ClipContigFunctor
105105
offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
106106
offset += sgSize)
107107
{
108-
dst_p[offset] =
109-
dpt_clamp(x_p[offset], min_p[offset], max_p[offset]);
108+
dst_p[offset] = clip(x_p[offset], min_p[offset], max_p[offset]);
110109
}
111110
}
112111
else {
@@ -142,32 +141,18 @@ template <typename T, int vec_sz = 4, int n_vecs = 2> class ClipContigFunctor
142141
x_vec = sg.load<vec_sz>(x_multi_ptr);
143142
min_vec = sg.load<vec_sz>(min_multi_ptr);
144143
max_vec = sg.load<vec_sz>(max_multi_ptr);
145-
if constexpr (std::is_floating_point_v<T> ||
146-
std::is_same_v<T, bool>) {
147144
#pragma unroll
148-
for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id)
149-
{
150-
dst_vec[vec_id] =
151-
dpt_clamp(x_vec[vec_id], min_vec[vec_id],
152-
max_vec[vec_id]);
153-
}
154-
}
155-
else {
156-
dst_vec = sycl::clamp(x_vec, min_vec, max_vec);
145+
for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
146+
dst_vec[vec_id] = clip(x_vec[vec_id], min_vec[vec_id],
147+
max_vec[vec_id]);
157148
}
158149
sg.store<vec_sz>(dst_multi_ptr, dst_vec);
159150
}
160151
}
161152
else {
162153
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
163154
k += sgSize) {
164-
if constexpr (std::is_floating_point_v<T> ||
165-
std::is_same_v<T, bool>) {
166-
dst_p[k] = dpt_clamp(x_p[k], min_p[k], max_p[k]);
167-
}
168-
else {
169-
dst_p[k] = sycl::clamp(x_p[k], min_p[k], max_p[k]);
170-
}
155+
dst_p[k] = clip(x_p[k], min_p[k], max_p[k]);
171156
}
172157
}
173158
}
@@ -243,18 +228,9 @@ template <typename T, typename IndexerT> class ClipStridedFunctor
243228
{
244229
size_t gid = id[0];
245230
auto offsets = indexer(static_cast<py::ssize_t>(gid));
246-
if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>) {
247-
dst_p[offsets.get_fourth_offset()] =
248-
sycl::clamp(x_p[offsets.get_first_offset()],
249-
min_p[offsets.get_second_offset()],
250-
max_p[offsets.get_third_offset()]);
251-
}
252-
else {
253-
dst_p[offsets.get_fourth_offset()] =
254-
dpt_clamp(x_p[offsets.get_first_offset()],
255-
min_p[offsets.get_second_offset()],
256-
max_p[offsets.get_third_offset()]);
257-
}
231+
dst_p[offsets.get_fourth_offset()] = clip(
232+
x_p[offsets.get_first_offset()], min_p[offsets.get_second_offset()],
233+
max_p[offsets.get_third_offset()]);
258234
}
259235
};
260236

dpctl/tests/test_tensor_clip.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,3 +488,11 @@ def test_clip_strided(dt):
488488
a_max = a_max[::-2]
489489
r = dpt.clip(x, min=-3, max=a_max)
490490
assert dpt.all(a_max == r)
491+
492+
493+
def test_clip_max_less_than_min():
494+
get_queue_or_skip()
495+
496+
x = dpt.ones(10, dtype="i4")
497+
res = dpt.clip(x, 5, 0)
498+
assert dpt.all(res == 0)

0 commit comments

Comments
 (0)