Skip to content

Commit 645fb04

Browse files
authored
[Reassociate] Use uint64_t for repeat count (#94232)
This patch relands #91469 and uses `uint64_t` for the repeat count to avoid the miscompilation caused by overflow that was reported in a comment on #91469.
1 parent d9507a3 commit 645fb04

File tree

2 files changed

+43
-122
lines changed

2 files changed

+43
-122
lines changed

llvm/lib/Transforms/Scalar/Reassociate.cpp

+12-108
Original file line numberDiff line numberDiff line change
@@ -302,98 +302,7 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
302302
return Res;
303303
}
304304

305-
/// Returns k such that lambda(2^Bitwidth) = 2^k, where lambda is the Carmichael
306-
/// function. This means that x^(2^k) === 1 mod 2^Bitwidth for
307-
/// every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic.
308-
/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every
309-
/// even x in Bitwidth-bit arithmetic.
/// Precondition: Bitwidth >= 1; with Bitwidth == 0 the unsigned subtraction
/// Bitwidth - 1 below would wrap around instead of yielding a valid shift.
310-
static unsigned CarmichaelShift(unsigned Bitwidth) {
311-
if (Bitwidth < 3)
312-
return Bitwidth - 1;
313-
return Bitwidth - 2;
314-
}
315-
316-
/// Add the extra weight 'RHS' to the existing weight 'LHS',
317-
/// reducing the combined weight using any special properties of the operation.
318-
/// The existing weight LHS represents the computation X op X op ... op X where
319-
/// X occurs LHS times. The combined weight represents X op X op ... op X with
320-
/// X occurring LHS + RHS times. If op is "Xor" for example then the combined
321-
/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even;
322-
/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second.
323-
static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) {
324-
// If we were working with infinite precision arithmetic then the combined
325-
// weight would be LHS + RHS. But we are using finite precision arithmetic,
326-
// and the APInt sum LHS + RHS may not be correct if it wraps (it is correct
327-
// for nilpotent operations and addition, but not for idempotent operations
328-
// and multiplication), so it is important to correctly reduce the combined
329-
// weight back into range if wrapping would be wrong.
330-
331-
// If RHS is zero then the weight didn't change.
332-
if (RHS.isMinValue())
333-
return;
334-
// If LHS is zero then the combined weight is RHS.
335-
if (LHS.isMinValue()) {
336-
LHS = RHS;
337-
return;
338-
}
339-
// From this point on we know that neither LHS nor RHS is zero.
340-
341-
if (Instruction::isIdempotent(Opcode)) {
342-
// Idempotent means X op X === X, so any non-zero weight is equivalent to a
343-
// weight of 1. Keeping weights at zero or one also means that wrapping is
344-
// not a problem.
345-
assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
346-
return; // Return a weight of 1.
347-
}
348-
if (Instruction::isNilpotent(Opcode)) {
349-
// Nilpotent means X op X === 0, so reduce weights modulo 2.
350-
assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
351-
LHS = 0; // 1 + 1 === 0 modulo 2.
352-
return;
353-
}
354-
if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) {
355-
// TODO: Reduce the weight by exploiting nsw/nuw?
356-
LHS += RHS;
357-
return;
358-
}
359-
360-
assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) &&
361-
"Unknown associative operation!");
362-
unsigned Bitwidth = LHS.getBitWidth();
363-
// If CM = lambda(2^Bitwidth) (Carmichael's lambda function) then a weight W satisfying W >= CM+Bitwidth
364-
// can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth
365-
// bit number x, since either x is odd in which case x^CM = 1, or x is even in
366-
// which case both x^W and x^(W - CM) are zero. By subtracting off multiples
367-
// of CM like this weights can always be reduced to the range [0, CM+Bitwidth)
368-
// which by a happy accident means that they can always be represented using
369-
// Bitwidth bits.
370-
// TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than
371-
// the value of Carmichael's lambda function).
372-
if (Bitwidth > 3) {
373-
/// CM - The value of Carmichael's lambda function.
374-
APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth));
375-
// Any weight W >= Threshold can be replaced with W - CM.
376-
APInt Threshold = CM + Bitwidth;
377-
assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!");
378-
// For Bitwidth 4 or more the following sum does not overflow.
379-
LHS += RHS;
380-
while (LHS.uge(Threshold))
381-
LHS -= CM;
382-
} else {
383-
// To avoid problems with overflow do everything the same as above but using
384-
// a larger type.
385-
unsigned CM = 1U << CarmichaelShift(Bitwidth);
386-
unsigned Threshold = CM + Bitwidth;
387-
assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold &&
388-
"Weights not reduced!");
389-
unsigned Total = LHS.getZExtValue() + RHS.getZExtValue();
390-
while (Total >= Threshold)
391-
Total -= CM;
392-
LHS = Total;
393-
}
394-
}
395-
396-
using RepeatedValue = std::pair<Value*, APInt>;
305+
using RepeatedValue = std::pair<Value *, uint64_t>;
397306

398307
/// Given an associative binary expression, return the leaf
399308
/// nodes in Ops along with their weights (how many times the leaf occurs). The
@@ -475,7 +384,6 @@ static bool LinearizeExprTree(Instruction *I,
475384
assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
476385
"Expected a UnaryOperator or BinaryOperator!");
477386
LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
478-
unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits();
479387
unsigned Opcode = I->getOpcode();
480388
assert(I->isAssociative() && I->isCommutative() &&
481389
"Expected an associative and commutative operation!");
@@ -490,8 +398,8 @@ static bool LinearizeExprTree(Instruction *I,
490398
// with their weights, representing a certain number of paths to the operator.
491399
// If an operator occurs in the worklist multiple times then we found multiple
492400
// ways to get to it.
493-
SmallVector<std::pair<Instruction*, APInt>, 8> Worklist; // (Op, Weight)
494-
Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1)));
401+
SmallVector<std::pair<Instruction *, uint64_t>, 8> Worklist; // (Op, Weight)
402+
Worklist.push_back(std::make_pair(I, 1));
495403
bool Changed = false;
496404

497405
// Leaves of the expression are values that either aren't the right kind of
@@ -509,7 +417,7 @@ static bool LinearizeExprTree(Instruction *I,
509417

510418
// Leaves - Keeps track of the set of putative leaves as well as the number of
511419
// paths to each leaf seen so far.
512-
using LeafMap = DenseMap<Value *, APInt>;
420+
using LeafMap = DenseMap<Value *, uint64_t>;
513421
LeafMap Leaves; // Leaf -> Total weight so far.
514422
SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order.
515423
const DataLayout DL = I->getModule()->getDataLayout();
@@ -518,8 +426,8 @@ static bool LinearizeExprTree(Instruction *I,
518426
SmallPtrSet<Value *, 8> Visited; // For checking the iteration scheme.
519427
#endif
520428
while (!Worklist.empty()) {
521-
std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
522-
I = P.first; // We examine the operands of this binary operator.
429+
// We examine the operands of this binary operator.
430+
auto [I, Weight] = Worklist.pop_back_val();
523431

524432
if (isa<OverflowingBinaryOperator>(I)) {
525433
Flags.HasNUW &= I->hasNoUnsignedWrap();
@@ -528,7 +436,6 @@ static bool LinearizeExprTree(Instruction *I,
528436

529437
for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
530438
Value *Op = I->getOperand(OpIdx);
531-
APInt Weight = P.second; // Number of paths to this operand.
532439
LLVM_DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n");
533440
assert(!Op->use_empty() && "No uses, so how did we get to it?!");
534441

@@ -562,7 +469,8 @@ static bool LinearizeExprTree(Instruction *I,
562469
"In leaf map but not visited!");
563470

564471
// Update the number of paths to the leaf.
565-
IncorporateWeight(It->second, Weight, Opcode);
472+
It->second += Weight;
473+
assert(It->second >= Weight && "Weight overflows");
566474

567475
// If we still have uses that are not accounted for by the expression
568476
// then it is not safe to modify the value.
@@ -625,10 +533,7 @@ static bool LinearizeExprTree(Instruction *I,
625533
// Node initially thought to be a leaf wasn't.
626534
continue;
627535
assert(!isReassociableOp(V, Opcode) && "Shouldn't be a leaf!");
628-
APInt Weight = It->second;
629-
if (Weight.isMinValue())
630-
// Leaf already output or weight reduction eliminated it.
631-
continue;
536+
uint64_t Weight = It->second;
632537
// Ensure the leaf is only output once.
633538
It->second = 0;
634539
Ops.push_back(std::make_pair(V, Weight));
@@ -642,7 +547,7 @@ static bool LinearizeExprTree(Instruction *I,
642547
if (Ops.empty()) {
643548
Constant *Identity = ConstantExpr::getBinOpIdentity(Opcode, I->getType());
644549
assert(Identity && "Associative operation without identity!");
645-
Ops.emplace_back(Identity, APInt(Bitwidth, 1));
550+
Ops.emplace_back(Identity, 1);
646551
}
647552

648553
return Changed;
@@ -1188,8 +1093,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
11881093
Factors.reserve(Tree.size());
11891094
for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
11901095
RepeatedValue E = Tree[i];
1191-
Factors.append(E.second.getZExtValue(),
1192-
ValueEntry(getRank(E.first), E.first));
1096+
Factors.append(E.second, ValueEntry(getRank(E.first), E.first));
11931097
}
11941098

11951099
bool FoundFactor = false;
@@ -2368,7 +2272,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
23682272
SmallVector<ValueEntry, 8> Ops;
23692273
Ops.reserve(Tree.size());
23702274
for (const RepeatedValue &E : Tree)
2371-
Ops.append(E.second.getZExtValue(), ValueEntry(getRank(E.first), E.first));
2275+
Ops.append(E.second, ValueEntry(getRank(E.first), E.first));
23722276

23732277
LLVM_DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n');
23742278

llvm/test/Transforms/Reassociate/repeats.ll

+31-14
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ define i3 @foo3x5(i3 %x) {
6060
; CHECK-SAME: i3 [[X:%.*]]) {
6161
; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[X]], [[X]]
6262
; CHECK-NEXT: [[TMP4:%.*]] = mul i3 [[TMP3]], [[X]]
63-
; CHECK-NEXT: ret i3 [[TMP4]]
63+
; CHECK-NEXT: [[TMP5:%.*]] = mul i3 [[TMP4]], [[TMP3]]
64+
; CHECK-NEXT: ret i3 [[TMP5]]
6465
;
6566
%tmp1 = mul i3 %x, %x
6667
%tmp2 = mul i3 %tmp1, %x
@@ -74,7 +75,8 @@ define i3 @foo3x5_nsw(i3 %x) {
7475
; CHECK-LABEL: define i3 @foo3x5_nsw(
7576
; CHECK-SAME: i3 [[X:%.*]]) {
7677
; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[X]], [[X]]
77-
; CHECK-NEXT: [[TMP4:%.*]] = mul nsw i3 [[TMP3]], [[X]]
78+
; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP3]], [[X]]
79+
; CHECK-NEXT: [[TMP4:%.*]] = mul i3 [[TMP2]], [[TMP3]]
7880
; CHECK-NEXT: ret i3 [[TMP4]]
7981
;
8082
%tmp1 = mul i3 %x, %x
@@ -89,7 +91,8 @@ define i3 @foo3x6(i3 %x) {
8991
; CHECK-LABEL: define i3 @foo3x6(
9092
; CHECK-SAME: i3 [[X:%.*]]) {
9193
; CHECK-NEXT: [[TMP1:%.*]] = mul i3 [[X]], [[X]]
92-
; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP1]], [[TMP1]]
94+
; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[TMP1]], [[X]]
95+
; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP3]], [[TMP3]]
9396
; CHECK-NEXT: ret i3 [[TMP2]]
9497
;
9598
%tmp1 = mul i3 %x, %x
@@ -106,7 +109,9 @@ define i3 @foo3x7(i3 %x) {
106109
; CHECK-SAME: i3 [[X:%.*]]) {
107110
; CHECK-NEXT: [[TMP5:%.*]] = mul i3 [[X]], [[X]]
108111
; CHECK-NEXT: [[TMP6:%.*]] = mul i3 [[TMP5]], [[X]]
109-
; CHECK-NEXT: ret i3 [[TMP6]]
112+
; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[TMP6]], [[X]]
113+
; CHECK-NEXT: [[TMP7:%.*]] = mul i3 [[TMP3]], [[TMP6]]
114+
; CHECK-NEXT: ret i3 [[TMP7]]
110115
;
111116
%tmp1 = mul i3 %x, %x
112117
%tmp2 = mul i3 %tmp1, %x
@@ -123,7 +128,8 @@ define i4 @foo4x8(i4 %x) {
123128
; CHECK-SAME: i4 [[X:%.*]]) {
124129
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
125130
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
126-
; CHECK-NEXT: ret i4 [[TMP4]]
131+
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]]
132+
; CHECK-NEXT: ret i4 [[TMP3]]
127133
;
128134
%tmp1 = mul i4 %x, %x
129135
%tmp2 = mul i4 %tmp1, %x
@@ -140,8 +146,9 @@ define i4 @foo4x9(i4 %x) {
140146
; CHECK-LABEL: define i4 @foo4x9(
141147
; CHECK-SAME: i4 [[X:%.*]]) {
142148
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
143-
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
144-
; CHECK-NEXT: [[TMP8:%.*]] = mul i4 [[TMP2]], [[TMP1]]
149+
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]]
150+
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]]
151+
; CHECK-NEXT: [[TMP8:%.*]] = mul i4 [[TMP3]], [[TMP2]]
145152
; CHECK-NEXT: ret i4 [[TMP8]]
146153
;
147154
%tmp1 = mul i4 %x, %x
@@ -160,7 +167,8 @@ define i4 @foo4x10(i4 %x) {
160167
; CHECK-LABEL: define i4 @foo4x10(
161168
; CHECK-SAME: i4 [[X:%.*]]) {
162169
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
163-
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
170+
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
171+
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]]
164172
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]]
165173
; CHECK-NEXT: ret i4 [[TMP3]]
166174
;
@@ -181,7 +189,8 @@ define i4 @foo4x11(i4 %x) {
181189
; CHECK-LABEL: define i4 @foo4x11(
182190
; CHECK-SAME: i4 [[X:%.*]]) {
183191
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
184-
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
192+
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
193+
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]]
185194
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]]
186195
; CHECK-NEXT: [[TMP10:%.*]] = mul i4 [[TMP3]], [[TMP2]]
187196
; CHECK-NEXT: ret i4 [[TMP10]]
@@ -204,7 +213,9 @@ define i4 @foo4x12(i4 %x) {
204213
; CHECK-LABEL: define i4 @foo4x12(
205214
; CHECK-SAME: i4 [[X:%.*]]) {
206215
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
207-
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]]
216+
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]]
217+
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]]
218+
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP3]], [[TMP3]]
208219
; CHECK-NEXT: ret i4 [[TMP2]]
209220
;
210221
%tmp1 = mul i4 %x, %x
@@ -227,7 +238,9 @@ define i4 @foo4x13(i4 %x) {
227238
; CHECK-SAME: i4 [[X:%.*]]) {
228239
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
229240
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
230-
; CHECK-NEXT: [[TMP12:%.*]] = mul i4 [[TMP2]], [[TMP1]]
241+
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]]
242+
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]]
243+
; CHECK-NEXT: [[TMP12:%.*]] = mul i4 [[TMP4]], [[TMP3]]
231244
; CHECK-NEXT: ret i4 [[TMP12]]
232245
;
233246
%tmp1 = mul i4 %x, %x
@@ -252,7 +265,9 @@ define i4 @foo4x14(i4 %x) {
252265
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
253266
; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]]
254267
; CHECK-NEXT: [[TMP7:%.*]] = mul i4 [[TMP6]], [[TMP6]]
255-
; CHECK-NEXT: ret i4 [[TMP7]]
268+
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP7]], [[X]]
269+
; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP4]], [[TMP4]]
270+
; CHECK-NEXT: ret i4 [[TMP5]]
256271
;
257272
%tmp1 = mul i4 %x, %x
258273
%tmp2 = mul i4 %tmp1, %x
@@ -276,8 +291,10 @@ define i4 @foo4x15(i4 %x) {
276291
; CHECK-SAME: i4 [[X:%.*]]) {
277292
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
278293
; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]]
279-
; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP6]], [[X]]
280-
; CHECK-NEXT: [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP6]]
294+
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP6]], [[TMP6]]
295+
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]]
296+
; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP4]], [[X]]
297+
; CHECK-NEXT: [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP4]]
281298
; CHECK-NEXT: ret i4 [[TMP14]]
282299
;
283300
%tmp1 = mul i4 %x, %x

0 commit comments

Comments
 (0)