Skip to content

Commit 5627794

Browse files
authored
MathExtras: avoid unnecessarily widening types (#95426)
Several multi-argument functions unnecessarily widen types beyond the argument types. Template'ize the functions, and use std::common_type_t to avoid this, hence optimizing the functions. A requirement of this patch is to change the overflow behavior of alignTo to only overflow when the result isn't representable in the return type.
1 parent 5cc1287 commit 5627794

File tree

2 files changed

+138
-69
lines changed

2 files changed

+138
-69
lines changed

llvm/include/llvm/Support/MathExtras.h

Lines changed: 122 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,22 @@
2323
#include <type_traits>
2424

2525
namespace llvm {
26+
/// Some template parameter helpers to optimize for bitwidth, for functions that
27+
/// take multiple arguments.
28+
29+
// We can't verify signedness, since callers rely on implicit coercions to
30+
// signed/unsigned.
31+
template <typename T, typename U>
32+
using enableif_int =
33+
std::enable_if_t<std::is_integral_v<T> && std::is_integral_v<U>>;
34+
35+
// Use std::common_type_t to widen only up to the widest argument.
36+
template <typename T, typename U, typename = enableif_int<T, U>>
37+
using common_uint =
38+
std::common_type_t<std::make_unsigned_t<T>, std::make_unsigned_t<U>>;
39+
template <typename T, typename U, typename = enableif_int<T, U>>
40+
using common_sint =
41+
std::common_type_t<std::make_signed_t<T>, std::make_signed_t<U>>;
2642

2743
/// Mathematical constants.
2844
namespace numbers {
@@ -346,7 +362,8 @@ inline unsigned Log2_64_Ceil(uint64_t Value) {
346362

347363
/// A and B are either alignments or offsets. Return the minimum alignment that
348364
/// may be assumed after adding the two together.
349-
constexpr uint64_t MinAlign(uint64_t A, uint64_t B) {
365+
template <typename U, typename V, typename T = common_uint<U, V>>
366+
constexpr T MinAlign(U A, V B) {
350367
// The largest power of 2 that divides both A and B.
351368
//
352369
// Replace "-Value" by "1+~Value" in the following commented code to avoid
@@ -355,6 +372,11 @@ constexpr uint64_t MinAlign(uint64_t A, uint64_t B) {
355372
return (A | B) & (1 + ~(A | B));
356373
}
357374

375+
/// Fallback when arguments aren't integral.
376+
constexpr uint64_t MinAlign(uint64_t A, uint64_t B) {
377+
return (A | B) & (1 + ~(A | B));
378+
}
379+
358380
/// Returns the next power of two (in 64-bits) that is strictly greater than A.
359381
/// Returns zero on overflow.
360382
constexpr uint64_t NextPowerOf2(uint64_t A) {
@@ -375,60 +397,17 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
375397
return UINT64_C(1) << Log2_64_Ceil(A);
376398
}
377399

378-
/// Returns the next integer (mod 2**64) that is greater than or equal to
379-
/// \p Value and is a multiple of \p Align. \p Align must be non-zero.
380-
///
381-
/// Examples:
382-
/// \code
383-
/// alignTo(5, 8) = 8
384-
/// alignTo(17, 8) = 24
385-
/// alignTo(~0LL, 8) = 0
386-
/// alignTo(321, 255) = 510
387-
/// \endcode
388-
///
389-
/// May overflow.
390-
inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
391-
assert(Align != 0u && "Align can't be 0.");
392-
return (Value + Align - 1) / Align * Align;
393-
}
394-
395-
inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
396-
assert(Align != 0 && (Align & (Align - 1)) == 0 &&
397-
"Align must be a power of 2");
398-
// Replace unary minus to avoid compilation error on Windows:
399-
// "unary minus operator applied to unsigned type, result still unsigned"
400-
uint64_t negAlign = (~Align) + 1;
401-
return (Value + Align - 1) & negAlign;
402-
}
403-
404-
/// If non-zero \p Skew is specified, the return value will be a minimal integer
405-
/// that is greater than or equal to \p Size and equal to \p A * N + \p Skew for
406-
/// some integer N. If \p Skew is larger than \p A, its value is adjusted to '\p
407-
/// Skew mod \p A'. \p Align must be non-zero.
408-
///
409-
/// Examples:
410-
/// \code
411-
/// alignTo(5, 8, 7) = 7
412-
/// alignTo(17, 8, 1) = 17
413-
/// alignTo(~0LL, 8, 3) = 3
414-
/// alignTo(321, 255, 42) = 552
415-
/// \endcode
416-
inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew) {
417-
assert(Align != 0u && "Align can't be 0.");
418-
Skew %= Align;
419-
return alignTo(Value - Skew, Align) + Skew;
420-
}
421-
422-
/// Returns the next integer (mod 2**64) that is greater than or equal to
423-
/// \p Value and is a multiple of \c Align. \c Align must be non-zero.
424-
template <uint64_t Align> constexpr uint64_t alignTo(uint64_t Value) {
425-
static_assert(Align != 0u, "Align must be non-zero");
426-
return (Value + Align - 1) / Align * Align;
427-
}
428-
429400
/// Returns the integer ceil(Numerator / Denominator). Unsigned version.
430401
/// Guaranteed to never overflow.
431-
inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
402+
template <typename U, typename V, typename T = common_uint<U, V>>
403+
constexpr T divideCeil(U Numerator, V Denominator) {
404+
assert(Denominator && "Division by zero");
405+
T Bias = (Numerator != 0);
406+
return (Numerator - Bias) / Denominator + Bias;
407+
}
408+
409+
/// Fallback when arguments aren't integral.
410+
constexpr uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
432411
assert(Denominator && "Division by zero");
433412
uint64_t Bias = (Numerator != 0);
434413
return (Numerator - Bias) / Denominator + Bias;
@@ -437,12 +416,13 @@ inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
437416
/// Returns the integer ceil(Numerator / Denominator). Signed version.
438417
/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
439418
/// is -1.
440-
inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
419+
template <typename U, typename V, typename T = common_sint<U, V>>
420+
constexpr T divideCeilSigned(U Numerator, V Denominator) {
441421
assert(Denominator && "Division by zero");
442422
if (!Numerator)
443423
return 0;
444424
// C's integer division rounds towards 0.
445-
int64_t Bias = (Denominator >= 0 ? 1 : -1);
425+
T Bias = Denominator >= 0 ? 1 : -1;
446426
bool SameSign = (Numerator >= 0) == (Denominator >= 0);
447427
return SameSign ? (Numerator - Bias) / Denominator + 1
448428
: Numerator / Denominator;
@@ -451,36 +431,111 @@ inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
451431
/// Returns the integer floor(Numerator / Denominator). Signed version.
452432
/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
453433
/// is -1.
454-
inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
434+
template <typename U, typename V, typename T = common_sint<U, V>>
435+
constexpr T divideFloorSigned(U Numerator, V Denominator) {
455436
assert(Denominator && "Division by zero");
456437
if (!Numerator)
457438
return 0;
458439
// C's integer division rounds towards 0.
459-
int64_t Bias = Denominator >= 0 ? -1 : 1;
440+
T Bias = Denominator >= 0 ? -1 : 1;
460441
bool SameSign = (Numerator >= 0) == (Denominator >= 0);
461442
return SameSign ? Numerator / Denominator
462443
: (Numerator - Bias) / Denominator - 1;
463444
}
464445

465446
/// Returns the remainder of the Euclidean division of LHS by RHS. Result is
466447
/// always non-negative.
467-
inline int64_t mod(int64_t Numerator, int64_t Denominator) {
448+
template <typename U, typename V, typename T = common_sint<U, V>>
449+
constexpr T mod(U Numerator, V Denominator) {
468450
assert(Denominator >= 1 && "Mod by non-positive number");
469-
int64_t Mod = Numerator % Denominator;
451+
T Mod = Numerator % Denominator;
470452
return Mod < 0 ? Mod + Denominator : Mod;
471453
}
472454

473455
/// Returns (Numerator / Denominator) rounded by round-half-up. Guaranteed to
474456
/// never overflow.
475-
inline uint64_t divideNearest(uint64_t Numerator, uint64_t Denominator) {
457+
template <typename U, typename V, typename T = common_uint<U, V>>
458+
constexpr T divideNearest(U Numerator, V Denominator) {
476459
assert(Denominator && "Division by zero");
477-
uint64_t Mod = Numerator % Denominator;
478-
return (Numerator / Denominator) + (Mod > (Denominator - 1) / 2);
460+
T Mod = Numerator % Denominator;
461+
return (Numerator / Denominator) +
462+
(Mod > (static_cast<T>(Denominator) - 1) / 2);
463+
}
464+
465+
/// Returns the next integer (mod 2**nbits) that is greater than or equal to
466+
/// \p Value and is a multiple of \p Align. \p Align must be non-zero.
467+
///
468+
/// Examples:
469+
/// \code
470+
/// alignTo(5, 8) = 8
471+
/// alignTo(17, 8) = 24
472+
/// alignTo(~0LL, 8) = 0
473+
/// alignTo(321, 255) = 510
474+
/// \endcode
475+
///
476+
/// Will overflow only if result is not representable in T.
477+
template <typename U, typename V, typename T = common_uint<U, V>>
478+
constexpr T alignTo(U Value, V Align) {
479+
assert(Align != 0u && "Align can't be 0.");
480+
T CeilDiv = divideCeil(Value, Align);
481+
return CeilDiv * Align;
482+
}
483+
484+
/// Fallback when arguments aren't integral.
485+
constexpr uint64_t alignTo(uint64_t Value, uint64_t Align) {
486+
assert(Align != 0u && "Align can't be 0.");
487+
uint64_t CeilDiv = divideCeil(Value, Align);
488+
return CeilDiv * Align;
489+
}
490+
491+
constexpr uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
492+
assert(Align != 0 && (Align & (Align - 1)) == 0 &&
493+
"Align must be a power of 2");
494+
// Replace unary minus to avoid compilation error on Windows:
495+
// "unary minus operator applied to unsigned type, result still unsigned"
496+
uint64_t NegAlign = (~Align) + 1;
497+
return (Value + Align - 1) & NegAlign;
498+
}
499+
500+
/// If non-zero \p Skew is specified, the return value will be a minimal integer
501+
/// that is greater than or equal to \p Size and equal to \p A * N + \p Skew for
502+
/// some integer N. If \p Skew is larger than \p A, its value is adjusted to '\p
503+
/// Skew mod \p A'. \p Align must be non-zero.
504+
///
505+
/// Examples:
506+
/// \code
507+
/// alignTo(5, 8, 7) = 7
508+
/// alignTo(17, 8, 1) = 17
509+
/// alignTo(~0LL, 8, 3) = 3
510+
/// alignTo(321, 255, 42) = 552
511+
/// \endcode
512+
///
513+
/// May overflow.
514+
template <typename U, typename V, typename W,
515+
typename T = common_uint<common_uint<U, V>, W>>
516+
constexpr T alignTo(U Value, V Align, W Skew) {
517+
assert(Align != 0u && "Align can't be 0.");
518+
Skew %= Align;
519+
return alignTo(Value - Skew, Align) + Skew;
479520
}
480521

481-
/// Returns the largest uint64_t less than or equal to \p Value and is
482-
/// \p Skew mod \p Align. \p Align must be non-zero
483-
inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
522+
/// Returns the next integer (mod 2**nbits) that is greater than or equal to
523+
/// \p Value and is a multiple of \c Align. \c Align must be non-zero.
524+
///
525+
/// Will overflow only if result is not representable in T.
526+
template <auto Align, typename V, typename T = common_uint<decltype(Align), V>>
527+
constexpr T alignTo(V Value) {
528+
static_assert(Align != 0u, "Align must be non-zero");
529+
T CeilDiv = divideCeil(Value, Align);
530+
return CeilDiv * Align;
531+
}
532+
533+
/// Returns the largest unsigned integer less than or equal to \p Value and is
534+
/// \p Skew mod \p Align. \p Align must be non-zero. Guaranteed to never
535+
/// overflow.
536+
template <typename U, typename V, typename W = uint8_t,
537+
typename T = common_uint<common_uint<U, V>, W>>
538+
constexpr T alignDown(U Value, V Align, W Skew = 0) {
484539
assert(Align != 0u && "Align can't be 0.");
485540
Skew %= Align;
486541
return (Value - Skew) / Align * Align + Skew;
@@ -524,8 +579,8 @@ inline int64_t SignExtend64(uint64_t X, unsigned B) {
524579

525580
/// Subtract two unsigned integers, X and Y, of type T and return the absolute
526581
/// value of the result.
527-
template <typename T>
528-
std::enable_if_t<std::is_unsigned_v<T>, T> AbsoluteDifference(T X, T Y) {
582+
template <typename U, typename V, typename T = common_uint<U, V>>
583+
constexpr T AbsoluteDifference(U X, V Y) {
529584
return X > Y ? (X - Y) : (Y - X);
530585
}
531586

llvm/unittests/Support/MathExtrasTest.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,21 +189,35 @@ TEST(MathExtras, AlignTo) {
189189
EXPECT_EQ(8u, alignTo(5, 8));
190190
EXPECT_EQ(24u, alignTo(17, 8));
191191
EXPECT_EQ(0u, alignTo(~0LL, 8));
192-
EXPECT_EQ(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1,
193-
alignTo(std::numeric_limits<uint32_t>::max(), 2));
192+
EXPECT_EQ(8u, alignTo(5ULL, 8ULL));
193+
194+
EXPECT_EQ(8u, alignTo<8>(5));
195+
EXPECT_EQ(24u, alignTo<8>(17));
196+
EXPECT_EQ(0u, alignTo<8>(~0LL));
197+
EXPECT_EQ(254u,
198+
alignTo<static_cast<uint8_t>(127)>(static_cast<uint8_t>(200)));
194199

195200
EXPECT_EQ(7u, alignTo(5, 8, 7));
196201
EXPECT_EQ(17u, alignTo(17, 8, 1));
197202
EXPECT_EQ(3u, alignTo(~0LL, 8, 3));
198203
EXPECT_EQ(552u, alignTo(321, 255, 42));
199204
EXPECT_EQ(std::numeric_limits<uint32_t>::max(),
200205
alignTo(std::numeric_limits<uint32_t>::max(), 2, 1));
206+
207+
// Overflow.
208+
EXPECT_EQ(0u, alignTo(static_cast<uint8_t>(200), static_cast<uint8_t>(128)));
209+
EXPECT_EQ(0u, alignTo<static_cast<uint8_t>(128)>(static_cast<uint8_t>(200)));
210+
EXPECT_EQ(0u, alignTo(static_cast<uint8_t>(200), static_cast<uint8_t>(128),
211+
static_cast<uint8_t>(0)));
212+
EXPECT_EQ(0u, alignTo(std::numeric_limits<uint32_t>::max(), 2));
201213
}
202214

203215
TEST(MathExtras, AlignToPowerOf2) {
216+
EXPECT_EQ(0u, alignToPowerOf2(0u, 8));
204217
EXPECT_EQ(8u, alignToPowerOf2(5, 8));
205218
EXPECT_EQ(24u, alignToPowerOf2(17, 8));
206219
EXPECT_EQ(0u, alignToPowerOf2(~0LL, 8));
220+
EXPECT_EQ(240u, alignToPowerOf2(240, 16));
207221
EXPECT_EQ(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1,
208222
alignToPowerOf2(std::numeric_limits<uint32_t>::max(), 2));
209223
}

0 commit comments

Comments
 (0)