Skip to content

Commit 69efc60

Browse files
joeatoddKornevNikita
authored andcommitted
[SYCL][COMPAT] Replace T{-1} with static_cast<T>(-1) for mask creation (#16527)
Also include `<limits>` header for `std::numeric_limits<unsigned char>::digits` (equal to `CHAR_BIT`), and tests for char types.
1 parent acfdff4 commit 69efc60

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

sycl/include/syclcompat/math.hpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#pragma once
3333

34+
#include <limits>
3435
#include <sycl/feature_test.hpp>
3536
#include <type_traits>
3637

@@ -308,7 +309,9 @@ inline T bfe(const T source, const uint32_t bit_start,
308309
// FIXME(syclcompat-lib-reviewers): This ternary was added to catch a case
309310
// which may be undefined anyway. Consider that we are losing perf here.
310311
const T mask =
311-
num_bits >= CHAR_BIT * sizeof(T) ? T{-1} : ((T{1} << num_bits) - 1);
312+
num_bits >= std::numeric_limits<unsigned char>::digits * sizeof(T)
313+
? static_cast<T>(-1)
314+
: ((static_cast<T>(1) << num_bits) - 1);
312315
return (source >> bit_start) & mask;
313316
}
314317

@@ -321,7 +324,7 @@ inline T bfe(const T source, const uint32_t bit_start,
321324
/// and source \param num_bits gives the bit field length in bits.
322325
///
323326
/// The result is padded with the sign bit of the extracted field. If `num_bits`
324-
/// is zero, the result is zero. If the start position is beyond the msb of the
327+
/// is zero, the result is zero. If the start position is beyond the msb of the
325328
/// input, the result is filled with the replicated sign bit of the extracted
326329
/// field.
327330
///
@@ -363,7 +366,8 @@ inline T bfe_safe(const T source, const uint32_t bit_start,
363366
return res;
364367
}
365368
#endif
366-
const uint32_t bit_width = CHAR_BIT * sizeof(T);
369+
const uint32_t bit_width =
370+
std::numeric_limits<unsigned char>::digits * sizeof(T);
367371
const uint32_t pos = std::min(bit_start, bit_width);
368372
const uint32_t len = std::min(pos + num_bits, bit_width) - pos;
369373
if constexpr (std::is_signed_v<T>) {
@@ -397,7 +401,8 @@ template <typename T>
397401
inline T bfi(const T x, const T y, const uint32_t bit_start,
398402
const uint32_t num_bits) {
399403
static_assert(std::is_unsigned_v<T>);
400-
constexpr unsigned bit_width = CHAR_BIT * sizeof(T);
404+
constexpr unsigned bit_width =
405+
std::numeric_limits<unsigned char>::digits * sizeof(T);
401406

402407
// if bit_start > bit_width || len == 0, should return y.
403408
const T ignore_bfi = static_cast<T>(bit_start > bit_width || num_bits == 0);
@@ -441,7 +446,8 @@ inline T bfi_safe(const T x, const T y, const uint32_t bit_start,
441446
return res;
442447
}
443448
#endif
444-
constexpr unsigned bit_width = CHAR_BIT * sizeof(T);
449+
constexpr unsigned bit_width =
450+
std::numeric_limits<unsigned char>::digits * sizeof(T);
445451
const uint32_t pos = std::min(bit_start, bit_width);
446452
const uint32_t len = std::min(pos + num_bits, bit_width) - pos;
447453
return syclcompat::detail::bfi(x, y, pos, len);

sycl/test-e2e/syclcompat/math/math_bfe.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
template <typename T>
4747
inline std::enable_if_t<std::is_integral_v<T>, T>
4848
bfe_slow(const T source, const uint32_t bit_start, const uint32_t num_bits) {
49-
const uint32_t msb = CHAR_BIT * sizeof(T) - 1;
49+
const uint32_t msb =
50+
std::numeric_limits<unsigned char>::digits * sizeof(T) - 1;
5051
const uint32_t pos = bit_start;
5152
const uint32_t len = num_bits;
5253

@@ -55,7 +56,8 @@ bfe_slow(const T source, const uint32_t bit_start, const uint32_t num_bits) {
5556
return 0ULL;
5657

5758
T sbit;
58-
std::bitset<CHAR_BIT * sizeof(T)> source_bitset(source);
59+
std::bitset<std::numeric_limits<unsigned char>::digits * sizeof(T)>
60+
source_bitset(source);
5961
if (std::is_unsigned_v<T> || len == 0)
6062
sbit = 0;
6163
else
@@ -67,16 +69,17 @@ bfe_slow(const T source, const uint32_t bit_start, const uint32_t num_bits) {
6769
if (bit_start > msb)
6870
return -sbit;
6971

70-
std::bitset<CHAR_BIT * sizeof(T)> result_bitset;
72+
std::bitset<std::numeric_limits<unsigned char>::digits * sizeof(T)>
73+
result_bitset;
7174
for (uint8_t i = 0; i <= msb; ++i)
7275
result_bitset[i] =
7376
(i < len && pos + i <= msb) ? source_bitset[pos + i] : sbit;
7477
return result_bitset.to_ullong();
7578
}
7679

7780
template <typename T> bool test(const char *Msg, int N) {
78-
uint32_t bit_width = CHAR_BIT * sizeof(T);
79-
T min_value = std::numeric_limits<T>::min();
81+
uint32_t bit_width = std::numeric_limits<unsigned char>::digits * sizeof(T);
82+
T min_value = std::numeric_limits<T>::lowest();
8083
T max_value = std::numeric_limits<T>::max();
8184
std::random_device rd;
8285
std::mt19937::result_type seed =
@@ -91,7 +94,9 @@ template <typename T> bool test(const char *Msg, int N) {
9194
.count());
9295

9396
std::mt19937 gen(seed);
94-
std::uniform_int_distribution<T> rd_source(min_value, max_value);
97+
// Support for char type with uniform_int_distribution isn't universal
98+
using RandomDataT = std::conditional_t<sizeof(T) == 1, int, T>;
99+
std::uniform_int_distribution<RandomDataT> rd_source(min_value, max_value);
95100

96101
// Define a small overshoot so that we adequately test out-of-range cases
97102
// without sacrificing depth of testing of valid start+length combinations
@@ -105,7 +110,7 @@ template <typename T> bool test(const char *Msg, int N) {
105110
std::vector<uint32_t> starts(N, 0);
106111
std::vector<uint32_t> lengths(N, 0);
107112
for (int i = 0; i < N; ++i) {
108-
sources[i] = rd_source(gen);
113+
sources[i] = static_cast<T>(rd_source(gen));
109114
starts[i] = rd_start(gen);
110115
lengths[i] = rd_length(gen);
111116
}
@@ -172,6 +177,8 @@ template <typename T> bool test(const char *Msg, int N) {
172177

173178
int main() {
174179
const int N = 1000;
180+
assert(test<int8_t>("int8", N));
181+
assert(test<uint8_t>("uint8", N));
175182
assert(test<int16_t>("int16", N));
176183
assert(test<uint16_t>("uint16", N));
177184
assert(test<int32_t>("int32", N));

0 commit comments

Comments
 (0)