Skip to content

Commit 22b63b9

Browse files
authored
Revert "[Reassociate] Drop weight reduction to fix issue 91417 (#91469)" (#94210)
Reverts 3bcccb6 and 9a28272 because #91469 causes a miscompilation (see the comment on #91469).
1 parent 12949c9 commit 22b63b9

File tree

3 files changed

+115
-40
lines changed

3 files changed

+115
-40
lines changed

llvm/lib/Transforms/Scalar/Reassociate.cpp

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,97 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
302302
return Res;
303303
}
304304

305+
/// Returns k such that lambda(2^Bitwidth) = 2^k, where lambda is the Carmichael
/// function. This means that x^(2^k) === 1 mod 2^Bitwidth for every odd x,
/// i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic.
///
/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every
/// even x in Bitwidth-bit arithmetic.
///
/// \param Bitwidth the width in bits of the integer type; must be non-zero.
/// \returns Bitwidth - 1 for widths 1 and 2, and Bitwidth - 2 for widths >= 3.
static unsigned CarmichaelShift(unsigned Bitwidth) {
  // A zero width would wrap the unsigned subtraction below to UINT_MAX, which
  // callers would then use as a shift amount.
  assert(Bitwidth > 0 && "Bitwidth must be non-zero!");
  return Bitwidth < 3 ? Bitwidth - 1 : Bitwidth - 2;
}
315+
316+
/// Add the extra weight 'RHS' to the existing weight 'LHS',
317+
/// reducing the combined weight using any special properties of the operation.
318+
/// The existing weight LHS represents the computation X op X op ... op X where
319+
/// X occurs LHS times. The combined weight represents X op X op ... op X with
320+
/// X occurring LHS + RHS times. If op is "Xor" for example then the combined
321+
/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even;
322+
/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second.
323+
static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) {
324+
// If we were working with infinite precision arithmetic then the combined
325+
// weight would be LHS + RHS. But we are using finite precision arithmetic,
326+
// and the APInt sum LHS + RHS may not be correct if it wraps (it is correct
327+
// for nilpotent operations and addition, but not for idempotent operations
328+
// and multiplication), so it is important to correctly reduce the combined
329+
// weight back into range if wrapping would be wrong.
330+
331+
// If RHS is zero then the weight didn't change.
332+
if (RHS.isMinValue())
333+
return;
334+
// If LHS is zero then the combined weight is RHS.
335+
if (LHS.isMinValue()) {
336+
LHS = RHS;
337+
return;
338+
}
339+
// From this point on we know that neither LHS nor RHS is zero.
340+
341+
if (Instruction::isIdempotent(Opcode)) {
342+
// Idempotent means X op X === X, so any non-zero weight is equivalent to a
343+
// weight of 1. Keeping weights at zero or one also means that wrapping is
344+
// not a problem.
345+
assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
346+
return; // Return a weight of 1.
347+
}
348+
if (Instruction::isNilpotent(Opcode)) {
349+
// Nilpotent means X op X === 0, so reduce weights modulo 2.
350+
assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
351+
LHS = 0; // 1 + 1 === 0 modulo 2.
352+
return;
353+
}
354+
if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) {
355+
// TODO: Reduce the weight by exploiting nsw/nuw?
356+
LHS += RHS;
357+
return;
358+
}
359+
360+
assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) &&
361+
"Unknown associative operation!");
362+
unsigned Bitwidth = LHS.getBitWidth();
363+
// If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth
364+
// can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth
365+
// bit number x, since either x is odd in which case x^CM = 1, or x is even in
366+
// which case both x^W and x^(W - CM) are zero. By subtracting off multiples
367+
// of CM like this weights can always be reduced to the range [0, CM+Bitwidth)
368+
// which by a happy accident means that they can always be represented using
369+
// Bitwidth bits.
370+
// TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than
371+
// the Carmichael number).
372+
if (Bitwidth > 3) {
373+
/// CM - The value of Carmichael's lambda function.
374+
APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth));
375+
// Any weight W >= Threshold can be replaced with W - CM.
376+
APInt Threshold = CM + Bitwidth;
377+
assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!");
378+
// For Bitwidth 4 or more the following sum does not overflow.
379+
LHS += RHS;
380+
while (LHS.uge(Threshold))
381+
LHS -= CM;
382+
} else {
383+
// To avoid problems with overflow do everything the same as above but using
384+
// a larger type.
385+
unsigned CM = 1U << CarmichaelShift(Bitwidth);
386+
unsigned Threshold = CM + Bitwidth;
387+
assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold &&
388+
"Weights not reduced!");
389+
unsigned Total = LHS.getZExtValue() + RHS.getZExtValue();
390+
while (Total >= Threshold)
391+
Total -= CM;
392+
LHS = Total;
393+
}
394+
}
395+
305396
// Pairs a Value pointer with an APInt. NOTE(review): presumably a leaf value
// together with its occurrence weight -- confirm against the alias's users.
using RepeatedValue = std::pair<Value*, APInt>;
306397

307398
/// Given an associative binary expression, return the leaf
@@ -471,7 +562,7 @@ static bool LinearizeExprTree(Instruction *I,
471562
"In leaf map but not visited!");
472563

473564
// Update the number of paths to the leaf.
474-
It->second += Weight;
565+
IncorporateWeight(It->second, Weight, Opcode);
475566

476567
// If we still have uses that are not accounted for by the expression
477568
// then it is not safe to modify the value.

llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,20 @@ define <8 x i1> @vector2(<8 x i1> %a, <8 x i1> %b0, <8 x i1> %b1, <8 x i1> %b2,
5656
; CHECK-NEXT: [[OR5:%.*]] = or <8 x i1> [[B5]], [[A]]
5757
; CHECK-NEXT: [[OR6:%.*]] = or <8 x i1> [[B6]], [[A]]
5858
; CHECK-NEXT: [[OR7:%.*]] = or <8 x i1> [[B7]], [[A]]
59-
; CHECK-NEXT: [[XOR0:%.*]] = xor <8 x i1> [[OR1]], [[OR0]]
60-
; CHECK-NEXT: [[XOR2:%.*]] = xor <8 x i1> [[XOR0]], [[OR2]]
61-
; CHECK-NEXT: [[OR045:%.*]] = xor <8 x i1> [[XOR2]], [[OR3]]
62-
; CHECK-NEXT: [[XOR3:%.*]] = xor <8 x i1> [[OR045]], [[OR4]]
63-
; CHECK-NEXT: [[XOR4:%.*]] = xor <8 x i1> [[XOR3]], [[OR5]]
64-
; CHECK-NEXT: [[XOR5:%.*]] = xor <8 x i1> [[XOR4]], [[OR6]]
65-
; CHECK-NEXT: [[XOR6:%.*]] = xor <8 x i1> [[XOR5]], [[OR7]]
59+
; CHECK-NEXT: [[XOR2:%.*]] = xor <8 x i1> [[OR1]], [[OR0]]
60+
; CHECK-NEXT: [[OR045:%.*]] = xor <8 x i1> [[XOR2]], [[OR2]]
61+
; CHECK-NEXT: [[XOR3:%.*]] = xor <8 x i1> [[OR045]], [[OR3]]
62+
; CHECK-NEXT: [[XOR4:%.*]] = xor <8 x i1> [[XOR3]], [[OR4]]
63+
; CHECK-NEXT: [[XOR5:%.*]] = xor <8 x i1> [[XOR4]], [[OR5]]
64+
; CHECK-NEXT: [[XOR6:%.*]] = xor <8 x i1> [[XOR5]], [[OR6]]
65+
; CHECK-NEXT: [[XOR7:%.*]] = xor <8 x i1> [[XOR6]], [[OR7]]
6666
; CHECK-NEXT: [[OR4560:%.*]] = or <8 x i1> [[OR045]], [[XOR2]]
6767
; CHECK-NEXT: [[OR023:%.*]] = or <8 x i1> [[OR4560]], [[XOR3]]
6868
; CHECK-NEXT: [[OR001:%.*]] = or <8 x i1> [[OR023]], [[XOR4]]
6969
; CHECK-NEXT: [[OR0123:%.*]] = or <8 x i1> [[OR001]], [[XOR5]]
7070
; CHECK-NEXT: [[OR01234567:%.*]] = or <8 x i1> [[OR0123]], [[XOR6]]
71-
; CHECK-NEXT: ret <8 x i1> [[OR01234567]]
71+
; CHECK-NEXT: [[OR1234567:%.*]] = or <8 x i1> [[OR01234567]], [[XOR7]]
72+
; CHECK-NEXT: ret <8 x i1> [[OR1234567]]
7273
;
7374
%or0 = or <8 x i1> %b0, %a
7475
%or1 = or <8 x i1> %b1, %a

llvm/test/Transforms/Reassociate/repeats.ll

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ define i8 @nilpotent(i8 %x) {
1515
define i2 @idempotent(i2 %x) {
1616
; CHECK-LABEL: define i2 @idempotent(
1717
; CHECK-SAME: i2 [[X:%.*]]) {
18-
; CHECK-NEXT: ret i2 -1
18+
; CHECK-NEXT: ret i2 [[X]]
1919
;
2020
%tmp1 = and i2 %x, %x
2121
%tmp2 = and i2 %tmp1, %x
@@ -60,8 +60,7 @@ define i3 @foo3x5(i3 %x) {
6060
; CHECK-SAME: i3 [[X:%.*]]) {
6161
; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[X]], [[X]]
6262
; CHECK-NEXT: [[TMP4:%.*]] = mul i3 [[TMP3]], [[X]]
63-
; CHECK-NEXT: [[TMP5:%.*]] = mul i3 [[TMP4]], [[TMP3]]
64-
; CHECK-NEXT: ret i3 [[TMP5]]
63+
; CHECK-NEXT: ret i3 [[TMP4]]
6564
;
6665
%tmp1 = mul i3 %x, %x
6766
%tmp2 = mul i3 %tmp1, %x
@@ -75,8 +74,7 @@ define i3 @foo3x5_nsw(i3 %x) {
7574
; CHECK-LABEL: define i3 @foo3x5_nsw(
7675
; CHECK-SAME: i3 [[X:%.*]]) {
7776
; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[X]], [[X]]
78-
; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP3]], [[X]]
79-
; CHECK-NEXT: [[TMP4:%.*]] = mul i3 [[TMP2]], [[TMP3]]
77+
; CHECK-NEXT: [[TMP4:%.*]] = mul nsw i3 [[TMP3]], [[X]]
8078
; CHECK-NEXT: ret i3 [[TMP4]]
8179
;
8280
%tmp1 = mul i3 %x, %x
@@ -91,8 +89,7 @@ define i3 @foo3x6(i3 %x) {
9189
; CHECK-LABEL: define i3 @foo3x6(
9290
; CHECK-SAME: i3 [[X:%.*]]) {
9391
; CHECK-NEXT: [[TMP1:%.*]] = mul i3 [[X]], [[X]]
94-
; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[TMP1]], [[X]]
95-
; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP3]], [[TMP3]]
92+
; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP1]], [[TMP1]]
9693
; CHECK-NEXT: ret i3 [[TMP2]]
9794
;
9895
%tmp1 = mul i3 %x, %x
@@ -108,9 +105,7 @@ define i3 @foo3x7(i3 %x) {
108105
; CHECK-LABEL: define i3 @foo3x7(
109106
; CHECK-SAME: i3 [[X:%.*]]) {
110107
; CHECK-NEXT: [[TMP5:%.*]] = mul i3 [[X]], [[X]]
111-
; CHECK-NEXT: [[TMP7:%.*]] = mul i3 [[TMP5]], [[X]]
112-
; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[TMP7]], [[X]]
113-
; CHECK-NEXT: [[TMP6:%.*]] = mul i3 [[TMP3]], [[TMP7]]
108+
; CHECK-NEXT: [[TMP6:%.*]] = mul i3 [[TMP5]], [[X]]
114109
; CHECK-NEXT: ret i3 [[TMP6]]
115110
;
116111
%tmp1 = mul i3 %x, %x
@@ -127,8 +122,7 @@ define i4 @foo4x8(i4 %x) {
127122
; CHECK-LABEL: define i4 @foo4x8(
128123
; CHECK-SAME: i4 [[X:%.*]]) {
129124
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
130-
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP1]], [[TMP1]]
131-
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP3]], [[TMP3]]
125+
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
132126
; CHECK-NEXT: ret i4 [[TMP4]]
133127
;
134128
%tmp1 = mul i4 %x, %x
@@ -146,9 +140,8 @@ define i4 @foo4x9(i4 %x) {
146140
; CHECK-LABEL: define i4 @foo4x9(
147141
; CHECK-SAME: i4 [[X:%.*]]) {
148142
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
149-
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]]
150-
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]]
151-
; CHECK-NEXT: [[TMP8:%.*]] = mul i4 [[TMP3]], [[TMP2]]
143+
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
144+
; CHECK-NEXT: [[TMP8:%.*]] = mul i4 [[TMP2]], [[TMP1]]
152145
; CHECK-NEXT: ret i4 [[TMP8]]
153146
;
154147
%tmp1 = mul i4 %x, %x
@@ -167,8 +160,7 @@ define i4 @foo4x10(i4 %x) {
167160
; CHECK-LABEL: define i4 @foo4x10(
168161
; CHECK-SAME: i4 [[X:%.*]]) {
169162
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
170-
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
171-
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]]
163+
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
172164
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]]
173165
; CHECK-NEXT: ret i4 [[TMP3]]
174166
;
@@ -189,8 +181,7 @@ define i4 @foo4x11(i4 %x) {
189181
; CHECK-LABEL: define i4 @foo4x11(
190182
; CHECK-SAME: i4 [[X:%.*]]) {
191183
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
192-
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
193-
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]]
184+
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
194185
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]]
195186
; CHECK-NEXT: [[TMP10:%.*]] = mul i4 [[TMP3]], [[TMP2]]
196187
; CHECK-NEXT: ret i4 [[TMP10]]
@@ -213,9 +204,7 @@ define i4 @foo4x12(i4 %x) {
213204
; CHECK-LABEL: define i4 @foo4x12(
214205
; CHECK-SAME: i4 [[X:%.*]]) {
215206
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
216-
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]]
217-
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]]
218-
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP3]], [[TMP3]]
207+
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]]
219208
; CHECK-NEXT: ret i4 [[TMP2]]
220209
;
221210
%tmp1 = mul i4 %x, %x
@@ -238,9 +227,7 @@ define i4 @foo4x13(i4 %x) {
238227
; CHECK-SAME: i4 [[X:%.*]]) {
239228
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
240229
; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
241-
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]]
242-
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]]
243-
; CHECK-NEXT: [[TMP12:%.*]] = mul i4 [[TMP4]], [[TMP3]]
230+
; CHECK-NEXT: [[TMP12:%.*]] = mul i4 [[TMP2]], [[TMP1]]
244231
; CHECK-NEXT: ret i4 [[TMP12]]
245232
;
246233
%tmp1 = mul i4 %x, %x
@@ -263,9 +250,7 @@ define i4 @foo4x14(i4 %x) {
263250
; CHECK-LABEL: define i4 @foo4x14(
264251
; CHECK-SAME: i4 [[X:%.*]]) {
265252
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
266-
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]]
267-
; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP4]], [[TMP4]]
268-
; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP5]], [[X]]
253+
; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]]
269254
; CHECK-NEXT: [[TMP7:%.*]] = mul i4 [[TMP6]], [[TMP6]]
270255
; CHECK-NEXT: ret i4 [[TMP7]]
271256
;
@@ -290,9 +275,7 @@ define i4 @foo4x15(i4 %x) {
290275
; CHECK-LABEL: define i4 @foo4x15(
291276
; CHECK-SAME: i4 [[X:%.*]]) {
292277
; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]]
293-
; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]]
294-
; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]]
295-
; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP3]], [[X]]
278+
; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]]
296279
; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP6]], [[X]]
297280
; CHECK-NEXT: [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP6]]
298281
; CHECK-NEXT: ret i4 [[TMP14]]

0 commit comments

Comments
 (0)