From 6b6a134f56a5e515c30528bc3a3bb0e9142d6ece Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Mon, 3 Jun 2024 23:12:58 +0800 Subject: [PATCH 1/2] [Reassociate] Use uint64_t for repeat count --- llvm/lib/Transforms/Scalar/Reassociate.cpp | 119 ++------------------ llvm/test/Transforms/Reassociate/repeats.ll | 45 +++++--- 2 files changed, 42 insertions(+), 122 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index c73d7c8d83bec..6cf097094ddd0 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -302,98 +302,7 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) { return Res; } -/// Returns k such that lambda(2^Bitwidth) = 2^k, where lambda is the Carmichael -/// function. This means that x^(2^k) === 1 mod 2^Bitwidth for -/// every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic. -/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every -/// even x in Bitwidth-bit arithmetic. -static unsigned CarmichaelShift(unsigned Bitwidth) { - if (Bitwidth < 3) - return Bitwidth - 1; - return Bitwidth - 2; -} - -/// Add the extra weight 'RHS' to the existing weight 'LHS', -/// reducing the combined weight using any special properties of the operation. -/// The existing weight LHS represents the computation X op X op ... op X where -/// X occurs LHS times. The combined weight represents X op X op ... op X with -/// X occurring LHS + RHS times. If op is "Xor" for example then the combined -/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even; -/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second. -static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) { - // If we were working with infinite precision arithmetic then the combined - // weight would be LHS + RHS. But we are using finite precision arithmetic, - // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct - // for nilpotent operations and addition, but not for idempotent operations - // and multiplication), so it is important to correctly reduce the combined - // weight back into range if wrapping would be wrong. - - // If RHS is zero then the weight didn't change. - if (RHS.isMinValue()) - return; - // If LHS is zero then the combined weight is RHS. - if (LHS.isMinValue()) { - LHS = RHS; - return; - } - // From this point on we know that neither LHS nor RHS is zero. - - if (Instruction::isIdempotent(Opcode)) { - // Idempotent means X op X === X, so any non-zero weight is equivalent to a - // weight of 1. Keeping weights at zero or one also means that wrapping is - // not a problem. - assert(LHS == 1 && RHS == 1 && "Weights not reduced!"); - return; // Return a weight of 1. - } - if (Instruction::isNilpotent(Opcode)) { - // Nilpotent means X op X === 0, so reduce weights modulo 2. - assert(LHS == 1 && RHS == 1 && "Weights not reduced!"); - LHS = 0; // 1 + 1 === 0 modulo 2. - return; - } - if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) { - // TODO: Reduce the weight by exploiting nsw/nuw? - LHS += RHS; - return; - } - - assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) && - "Unknown associative operation!"); - unsigned Bitwidth = LHS.getBitWidth(); - // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth - // can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth - // bit number x, since either x is odd in which case x^CM = 1, or x is even in - // which case both x^W and x^(W - CM) are zero. By subtracting off multiples - // of CM like this weights can always be reduced to the range [0, CM+Bitwidth) - // which by a happy accident means that they can always be represented using - // Bitwidth bits. - // TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than - // the Carmichael number). - if (Bitwidth > 3) { - /// CM - The value of Carmichael's lambda function. - APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth)); - // Any weight W >= Threshold can be replaced with W - CM. - APInt Threshold = CM + Bitwidth; - assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!"); - // For Bitwidth 4 or more the following sum does not overflow. - LHS += RHS; - while (LHS.uge(Threshold)) - LHS -= CM; - } else { - // To avoid problems with overflow do everything the same as above but using - // a larger type. - unsigned CM = 1U << CarmichaelShift(Bitwidth); - unsigned Threshold = CM + Bitwidth; - assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold && - "Weights not reduced!"); - unsigned Total = LHS.getZExtValue() + RHS.getZExtValue(); - while (Total >= Threshold) - Total -= CM; - LHS = Total; - } -} - -using RepeatedValue = std::pair; +using RepeatedValue = std::pair; /// Given an associative binary expression, return the leaf /// nodes in Ops along with their weights (how many times the leaf occurs). The @@ -475,7 +384,6 @@ static bool LinearizeExprTree(Instruction *I, assert((isa(I) || isa(I)) && "Expected a UnaryOperator or BinaryOperator!"); LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); - unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits(); unsigned Opcode = I->getOpcode(); assert(I->isAssociative() && I->isCommutative() && "Expected an associative and commutative operation!"); @@ -490,8 +398,8 @@ static bool LinearizeExprTree(Instruction *I, // with their weights, representing a certain number of paths to the operator. // If an operator occurs in the worklist multiple times then we found multiple // ways to get to it. - SmallVector, 8> Worklist; // (Op, Weight) - Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1))); + SmallVector, 8> Worklist; // (Op, Weight) + Worklist.push_back(std::make_pair(I, 1)); bool Changed = false; // Leaves of the expression are values that either aren't the right kind of @@ -509,7 +417,7 @@ static bool LinearizeExprTree(Instruction *I, // Leaves - Keeps track of the set of putative leaves as well as the number of // paths to each leaf seen so far. - using LeafMap = DenseMap; + using LeafMap = DenseMap; LeafMap Leaves; // Leaf -> Total weight so far. SmallVector LeafOrder; // Ensure deterministic leaf output order. const DataLayout DL = I->getModule()->getDataLayout(); @@ -518,8 +426,8 @@ static bool LinearizeExprTree(Instruction *I, SmallPtrSet Visited; // For checking the iteration scheme. #endif while (!Worklist.empty()) { - std::pair P = Worklist.pop_back_val(); - I = P.first; // We examine the operands of this binary operator. + // We examine the operands of this binary operator. + auto [I, Weight] = Worklist.pop_back_val(); if (isa(I)) { Flags.HasNUW &= I->hasNoUnsignedWrap(); @@ -528,7 +436,6 @@ static bool LinearizeExprTree(Instruction *I, for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); - APInt Weight = P.second; // Number of paths to this operand. LLVM_DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n"); assert(!Op->use_empty() && "No uses, so how did we get to it?!"); @@ -562,7 +469,7 @@ static bool LinearizeExprTree(Instruction *I, "In leaf map but not visited!"); // Update the number of paths to the leaf. - IncorporateWeight(It->second, Weight, Opcode); + It->second += Weight; // If we still have uses that are not accounted for by the expression // then it is not safe to modify the value. @@ -625,10 +532,7 @@ static bool LinearizeExprTree(Instruction *I, // Node initially thought to be a leaf wasn't. continue; assert(!isReassociableOp(V, Opcode) && "Shouldn't be a leaf!"); - APInt Weight = It->second; - if (Weight.isMinValue()) - // Leaf already output or weight reduction eliminated it. - continue; + uint64_t Weight = It->second; // Ensure the leaf is only output once. It->second = 0; Ops.push_back(std::make_pair(V, Weight)); @@ -642,7 +546,7 @@ static bool LinearizeExprTree(Instruction *I, if (Ops.empty()) { Constant *Identity = ConstantExpr::getBinOpIdentity(Opcode, I->getType()); assert(Identity && "Associative operation without identity!"); - Ops.emplace_back(Identity, APInt(Bitwidth, 1)); + Ops.emplace_back(Identity, 1); } return Changed; @@ -1188,8 +1092,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { Factors.reserve(Tree.size()); for (unsigned i = 0, e = Tree.size(); i != e; ++i) { RepeatedValue E = Tree[i]; - Factors.append(E.second.getZExtValue(), - ValueEntry(getRank(E.first), E.first)); + Factors.append(E.second, ValueEntry(getRank(E.first), E.first)); } bool FoundFactor = false; @@ -2368,7 +2271,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { SmallVector Ops; Ops.reserve(Tree.size()); for (const RepeatedValue &E : Tree) - Ops.append(E.second.getZExtValue(), ValueEntry(getRank(E.first), E.first)); + Ops.append(E.second, ValueEntry(getRank(E.first), E.first)); LLVM_DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n'); diff --git a/llvm/test/Transforms/Reassociate/repeats.ll b/llvm/test/Transforms/Reassociate/repeats.ll index ba25c4bfc643c..8600777877bb3 100644 --- a/llvm/test/Transforms/Reassociate/repeats.ll +++ b/llvm/test/Transforms/Reassociate/repeats.ll @@ -60,7 +60,8 @@ define i3 @foo3x5(i3 %x) { ; CHECK-SAME: i3 [[X:%.*]]) { ; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[X]], [[X]] ; CHECK-NEXT: [[TMP4:%.*]] = mul i3 [[TMP3]], [[X]] -; CHECK-NEXT: ret i3 [[TMP4]] +; CHECK-NEXT: [[TMP5:%.*]] = mul i3 [[TMP4]], [[TMP3]] +; CHECK-NEXT: ret i3 [[TMP5]] ; %tmp1 = mul i3 %x, %x %tmp2 = mul i3 %tmp1, %x @@ -74,7 +75,8 @@ define i3 @foo3x5_nsw(i3 %x) { ; CHECK-LABEL: define i3 @foo3x5_nsw( ; CHECK-SAME: i3 [[X:%.*]]) { ; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[X]], [[X]] -; CHECK-NEXT: [[TMP4:%.*]] = mul nsw i3 [[TMP3]], [[X]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP3]], [[X]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i3 [[TMP2]], [[TMP3]] ; CHECK-NEXT: ret i3 [[TMP4]] ; %tmp1 = mul i3 %x, %x @@ -89,7 +91,8 @@ define i3 @foo3x6(i3 %x) { ; CHECK-LABEL: define i3 @foo3x6( ; CHECK-SAME: i3 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i3 [[X]], [[X]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[TMP1]], [[X]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP3]], [[TMP3]] ; CHECK-NEXT: ret i3 [[TMP2]] ; %tmp1 = mul i3 %x, %x @@ -106,7 +109,9 @@ define i3 @foo3x7(i3 %x) { ; CHECK-SAME: i3 [[X:%.*]]) { ; CHECK-NEXT: [[TMP5:%.*]] = mul i3 [[X]], [[X]] ; CHECK-NEXT: [[TMP6:%.*]] = mul i3 [[TMP5]], [[X]] -; CHECK-NEXT: ret i3 [[TMP6]] +; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[TMP6]], [[X]] +; CHECK-NEXT: [[TMP7:%.*]] = mul i3 [[TMP3]], [[TMP6]] +; CHECK-NEXT: ret i3 [[TMP7]] ; %tmp1 = mul i3 %x, %x %tmp2 = mul i3 %tmp1, %x @@ -123,7 +128,8 @@ define i4 @foo4x8(i4 %x) { ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] ; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]] -; CHECK-NEXT: ret i4 [[TMP4]] +; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]] +; CHECK-NEXT: ret i4 [[TMP3]] ; %tmp1 = mul i4 %x, %x %tmp2 = mul i4 %tmp1, %x @@ -140,8 +146,9 @@ define i4 @foo4x9(i4 %x) { ; CHECK-LABEL: define i4 @foo4x9( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]] -; CHECK-NEXT: [[TMP8:%.*]] = mul i4 [[TMP2]], [[TMP1]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]] +; CHECK-NEXT: [[TMP8:%.*]] = mul i4 [[TMP3]], [[TMP2]] ; CHECK-NEXT: ret i4 [[TMP8]] ; %tmp1 = mul i4 %x, %x @@ -160,7 +167,8 @@ define i4 @foo4x10(i4 %x) { ; CHECK-LABEL: define i4 @foo4x10( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]] ; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]] ; CHECK-NEXT: ret i4 [[TMP3]] ; @@ -181,7 +189,8 @@ define i4 @foo4x11(i4 %x) { ; CHECK-LABEL: define i4 @foo4x11( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]] ; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]] ; CHECK-NEXT: [[TMP10:%.*]] = mul i4 [[TMP3]], [[TMP2]] ; CHECK-NEXT: ret i4 [[TMP10]] @@ -204,7 +213,9 @@ define i4 @foo4x12(i4 %x) { ; CHECK-LABEL: define i4 @foo4x12( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]] +; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP3]], [[TMP3]] ; CHECK-NEXT: ret i4 [[TMP2]] ; %tmp1 = mul i4 %x, %x @@ -227,7 +238,9 @@ define i4 @foo4x13(i4 %x) { ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] ; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]] -; CHECK-NEXT: [[TMP12:%.*]] = mul i4 [[TMP2]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]] +; CHECK-NEXT: [[TMP12:%.*]] = mul i4 [[TMP4]], [[TMP3]] ; CHECK-NEXT: ret i4 [[TMP12]] ; %tmp1 = mul i4 %x, %x @@ -252,7 +265,9 @@ define i4 @foo4x14(i4 %x) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] ; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]] ; CHECK-NEXT: [[TMP7:%.*]] = mul i4 [[TMP6]], [[TMP6]] -; CHECK-NEXT: ret i4 [[TMP7]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP7]], [[X]] +; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP4]], [[TMP4]] +; CHECK-NEXT: ret i4 [[TMP5]] ; %tmp1 = mul i4 %x, %x %tmp2 = mul i4 %tmp1, %x @@ -276,8 +291,10 @@ define i4 @foo4x15(i4 %x) { ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] ; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]] -; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP6]], [[X]] -; CHECK-NEXT: [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP6]] +; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP6]], [[TMP6]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]] +; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP4]], [[X]] +; CHECK-NEXT: [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP4]] ; CHECK-NEXT: ret i4 [[TMP14]] ; %tmp1 = mul i4 %x, %x From cd30a7a533174f6a4bf03635cd9cf4365410e944 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Sat, 8 Jun 2024 00:13:40 +0800 Subject: [PATCH 2/2] [Reassociate] Add overflow checks. --- llvm/lib/Transforms/Scalar/Reassociate.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index 6cf097094ddd0..f36e21b296bd1 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -470,6 +470,7 @@ static bool LinearizeExprTree(Instruction *I, // Update the number of paths to the leaf. It->second += Weight; + assert(It->second >= Weight && "Weight overflows"); // If we still have uses that are not accounted for by the expression // then it is not safe to modify the value.