From 1e9dbac8b5578a3d10228466177f2d7c605978ba Mon Sep 17 00:00:00 2001 From: Paul Kirth Date: Tue, 26 Mar 2024 00:49:16 +0000 Subject: [PATCH 1/4] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?= =?UTF-8?q?anges=20to=20main=20this=20commit=20is=20based=20on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- llvm/include/llvm/IR/ProfDataUtils.h | 9 +++- llvm/lib/IR/ProfDataUtils.cpp | 45 +++++++++++++------ .../Transforms/Utils/LoopRotationUtils.cpp | 2 +- llvm/lib/Transforms/Utils/MisExpect.cpp | 3 +- llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 7 +-- 5 files changed, 43 insertions(+), 23 deletions(-) diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index 255fa2ff1c790..dc983eed13a8d 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -65,10 +65,15 @@ bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights); /// Faster version of extractBranchWeights() that skips checks and must only -/// be called with "branch_weights" metadata nodes. -void extractFromBranchWeightMD(const MDNode *ProfileData, +/// be called with "branch_weights" metadata nodes. Supports uint32_t. +void extractFromBranchWeightMD32(const MDNode *ProfileData, SmallVectorImpl &Weights); +/// Faster version of extractBranchWeights() that skips checks and must only +/// be called with "branch_weights" metadata nodes. Supports uint64_t. +void extractFromBranchWeightMD64(const MDNode *ProfileData, + SmallVectorImpl &Weights); + /// Extract branch weights attatched to an Instruction /// /// \param I The Instruction to extract weights from. diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index b1a10d0ce5a52..b4e09e76993f9 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { return ProfDataName->getString().equals(Name); } +template >> +static void extractFromBranchWeightMD(const MDNode *ProfileData, + SmallVectorImpl &Weights) { + assert(isBranchWeightMD(ProfileData) && "wrong metadata"); + + unsigned NOps = ProfileData->getNumOperands(); + assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); + Weights.resize(NOps - WeightsIdx); + + for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { + ConstantInt *Weight = + mdconst::dyn_extract(ProfileData->getOperand(Idx)); + assert(Weight && "Malformed branch_weight in MD_prof node"); + assert(Weight->getValue().getActiveBits() <= 32 && + "Too many bits for uint32_t"); + Weights[Idx - WeightsIdx] = Weight->getZExtValue(); + } +} + } // namespace namespace llvm { @@ -100,24 +120,21 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) { return nullptr; } -void extractFromBranchWeightMD(const MDNode *ProfileData, +void extractFromBranchWeightMD32(const MDNode *ProfileData, SmallVectorImpl &Weights) { - assert(isBranchWeightMD(ProfileData) && "wrong metadata"); - - unsigned NOps = ProfileData->getNumOperands(); - assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); - Weights.resize(NOps - WeightsIdx); + extractFromBranchWeightMD(ProfileData, Weights); +} - for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { - ConstantInt *Weight = - mdconst::dyn_extract(ProfileData->getOperand(Idx)); - assert(Weight && "Malformed branch_weight in MD_prof node"); - assert(Weight->getValue().getActiveBits() <= 32 && - "Too many bits for uint32_t"); - Weights[Idx - WeightsIdx] = Weight->getZExtValue(); - } +void extractFromBranchWeightMD64(const MDNode *ProfileData, + SmallVectorImpl &Weights) { + extractFromBranchWeightMD(ProfileData, Weights); } + + + + + bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights) { if (!isBranchWeightMD(ProfileData)) diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index bc67117113719..f4b43ce370a5d 100644 --- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI, return; SmallVector Weights; - extractFromBranchWeightMD(WeightMD, Weights); + extractFromBranchWeightMD32(WeightMD, Weights); if (Weights.size() != 2) return; uint32_t OrigLoopExitWeight = Weights[0]; diff --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp index 6f5a25a26821b..759289384ee06 100644 --- a/llvm/lib/Transforms/Utils/MisExpect.cpp +++ b/llvm/lib/Transforms/Utils/MisExpect.cpp @@ -59,9 +59,10 @@ static cl::opt PGOWarnMisExpect( cl::desc("Use this option to turn on/off " "warnings about incorrect usage of llvm.expect intrinsics.")); +// Command line option for setting the diagnostic tolerance threshold static cl::opt MisExpectTolerance( "misexpect-tolerance", cl::init(0), - cl::desc("Prevents emiting diagnostics when profile counts are " + cl::desc("Prevents emitting diagnostics when profile counts are " "within N% of the threshold..")); } // namespace llvm diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 55bbffb18879f..a425e26d490e4 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -1065,11 +1065,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1, static void GetBranchWeights(Instruction *TI, SmallVectorImpl &Weights) { MDNode *MD = TI->getMetadata(LLVMContext::MD_prof); - assert(MD); - for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { - ConstantInt *CI = mdconst::extract(MD->getOperand(i)); - Weights.push_back(CI->getValue().getZExtValue()); - } + assert(MD && "Invalid branch-weight metadata"); + extractFromBranchWeightMD64(MD, Weights); // If TI is a conditional eq, the default case is the false case, // and the corresponding branch-weight data is at index 2. We swap the From ee503bf633c47f9b8abb76848327bcbd2b769be3 Mon Sep 17 00:00:00 2001 From: Paul Kirth Date: Tue, 26 Mar 2024 00:56:29 +0000 Subject: [PATCH 2/4] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?= =?UTF-8?q?anges=20introduced=20through=20rebase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- llvm/include/llvm/IR/ProfDataUtils.h | 4 ++-- llvm/lib/IR/ProfDataUtils.cpp | 9 ++------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index dc983eed13a8d..457ffdff8fe37 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -67,12 +67,12 @@ bool extractBranchWeights(const MDNode *ProfileData, /// Faster version of extractBranchWeights() that skips checks and must only /// be called with "branch_weights" metadata nodes. Supports uint32_t. void extractFromBranchWeightMD32(const MDNode *ProfileData, - SmallVectorImpl &Weights); + SmallVectorImpl &Weights); /// Faster version of extractBranchWeights() that skips checks and must only /// be called with "branch_weights" metadata nodes. Supports uint64_t. void extractFromBranchWeightMD64(const MDNode *ProfileData, - SmallVectorImpl &Weights); + SmallVectorImpl &Weights); /// Extract branch weights attatched to an Instruction /// diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index b4e09e76993f9..36e165e641f46 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -121,20 +121,15 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) { } void extractFromBranchWeightMD32(const MDNode *ProfileData, - SmallVectorImpl &Weights) { + SmallVectorImpl &Weights) { extractFromBranchWeightMD(ProfileData, Weights); } void extractFromBranchWeightMD64(const MDNode *ProfileData, - SmallVectorImpl &Weights) { + SmallVectorImpl &Weights) { extractFromBranchWeightMD(ProfileData, Weights); } - - - - - bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights) { if (!isBranchWeightMD(ProfileData)) From 7760282ed8dba340d6873d06ff4c18c6efc25b56 Mon Sep 17 00:00:00 2001 From: Paul Kirth Date: Thu, 6 Jun 2024 19:04:19 +0000 Subject: [PATCH 3/4] Add assert for metadata string value Created using spr 1.3.4 --- llvm/lib/IR/ProfDataUtils.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index f738d76937c24..af536d2110eac 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -133,6 +133,7 @@ bool hasBranchWeightProvenance(const MDNode *ProfileData) { // NOTE: if we ever have more types of branch weight provenance, // we need to check the string value is "expected". For now, we // supply a more generic API, and avoid the spurious comparisons. + assert(ProfDataName->getString() == "expected"); return ProfDataName; } From 947f9e16732197418c9f49ed02a01b187a50f936 Mon Sep 17 00:00:00 2001 From: Paul Kirth Date: Thu, 6 Jun 2024 21:13:24 +0000 Subject: [PATCH 4/4] Rename hasBranchWeightProvenance to hasBranchWeightOrigin Created using spr 1.3.4 --- llvm/include/llvm/IR/ProfDataUtils.h | 4 ++-- llvm/lib/CodeGen/CodeGenPrepare.cpp | 9 ++++----- llvm/lib/IR/Metadata.cpp | 8 ++++---- llvm/lib/IR/ProfDataUtils.cpp | 14 +++++++------- llvm/lib/IR/Verifier.cpp | 2 +- llvm/lib/Transforms/Scalar/JumpThreading.cpp | 4 ++-- llvm/lib/Transforms/Utils/Local.cpp | 2 +- 7 files changed, 21 insertions(+), 22 deletions(-) diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index 3c761bdc1bf3e..1d7c97d9be953 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -57,11 +57,11 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I); /// Check if Branch Weight Metadata has an "expected" field from an llvm.expect* /// intrinsic -bool hasBranchWeightProvenance(const Instruction &I); +bool hasBranchWeightOrigin(const Instruction &I); /// Check if Branch Weight Metadata has an "expected" field from an llvm.expect* /// intrinsic -bool hasBranchWeightProvenance(const MDNode *ProfileData); +bool hasBranchWeightOrigin(const MDNode *ProfileData); /// Return the offset to the first branch weight data unsigned getBranchWeightOffset(const MDNode *ProfileData); diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index ce11caca38988..0e01080bd75cc 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -8864,11 +8864,10 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, ModifyDT &ModifiedDT) { uint64_t NewTrueWeight = TrueWeight; uint64_t NewFalseWeight = TrueWeight + 2 * FalseWeight; scaleWeights(NewTrueWeight, NewFalseWeight); - Br1->setMetadata( - LLVMContext::MD_prof, - MDBuilder(Br1->getContext()) - .createBranchWeights(TrueWeight, FalseWeight, - hasBranchWeightProvenance(*Br1))); + Br1->setMetadata(LLVMContext::MD_prof, + MDBuilder(Br1->getContext()) + .createBranchWeights(TrueWeight, FalseWeight, + hasBranchWeightOrigin(*Br1))); NewTrueWeight = TrueWeight; NewFalseWeight = 2 * FalseWeight; diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp index b6c932495a145..5f42ce22f72fe 100644 --- a/llvm/lib/IR/Metadata.cpp +++ b/llvm/lib/IR/Metadata.cpp @@ -1196,10 +1196,10 @@ MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B, StringRef AProfName = AMDS->getString(); StringRef BProfName = BMDS->getString(); if (AProfName == "branch_weights" && BProfName == "branch_weights") { - ConstantInt *AInstrWeight = - mdconst::dyn_extract(A->getOperand(1)); - ConstantInt *BInstrWeight = - mdconst::dyn_extract(B->getOperand(1)); + ConstantInt *AInstrWeight = mdconst::dyn_extract( + A->getOperand(getBranchWeightOffset(A))); + ConstantInt *BInstrWeight = mdconst::dyn_extract( + B->getOperand(getBranchWeightOffset(B))); assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier"); return MDNode::get(Ctx, {MDHelper.createString("branch_weights"), diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index af536d2110eac..c4b1ed55de8a2 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -121,24 +121,24 @@ bool hasValidBranchWeightMD(const Instruction &I) { return getValidBranchWeightMDNode(I); } -bool hasBranchWeightProvenance(const Instruction &I) { +bool hasBranchWeightOrigin(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); - return hasBranchWeightProvenance(ProfileData); + return hasBranchWeightOrigin(ProfileData); } -bool hasBranchWeightProvenance(const MDNode *ProfileData) { +bool hasBranchWeightOrigin(const MDNode *ProfileData) { if (!isBranchWeightMD(ProfileData)) return false; auto *ProfDataName = dyn_cast(ProfileData->getOperand(1)); // NOTE: if we ever have more types of branch weight provenance, // we need to check the string value is "expected". For now, we // supply a more generic API, and avoid the spurious comparisons. - assert(ProfDataName->getString() == "expected"); - return ProfDataName; + assert(ProfDataName == nullptr || ProfDataName->getString() == "expected"); + return ProfDataName != nullptr; } unsigned getBranchWeightOffset(const MDNode *ProfileData) { - return hasBranchWeightProvenance(ProfileData) ? 2 : 1; + return hasBranchWeightOrigin(ProfileData) ? 2 : 1; } MDNode *getBranchWeightMDNode(const Instruction &I) { @@ -210,7 +210,7 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { if (!ProfDataName) return false; - if (ProfDataName->getString().equals("branch_weights")) { + if (ProfDataName->getString() == "branch_weights") { unsigned Offset = getBranchWeightOffset(ProfileData); for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { auto *V = mdconst::dyn_extract(ProfileData->getOperand(Idx)); diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 39185905a1516..e0fde2b7d90dc 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -4808,7 +4808,7 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) { StringRef ProfName = MDS->getString(); // Check consistency of !prof branch_weights metadata. - if (ProfName.equals("branch_weights")) { + if (ProfName == "branch_weights") { unsigned int Offset = getBranchWeightOffset(MD); if (isa(&I)) { Check(MD->getNumOperands() == (1 + Offset) || diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index 88307b8b074ed..b9583836aea06 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -231,7 +231,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { Weights[0] = BP.getCompl().getNumerator(); Weights[1] = BP.getNumerator(); } - setBranchWeights(*PredBr, Weights, hasBranchWeightProvenance(*PredBr)); + setBranchWeights(*PredBr, Weights, hasBranchWeightOrigin(*PredBr)); } } @@ -2618,7 +2618,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB, Weights.push_back(Prob.getNumerator()); auto TI = BB->getTerminator(); - setBranchWeights(*TI, Weights, hasBranchWeightProvenance(*TI)); + setBranchWeights(*TI, Weights, hasBranchWeightOrigin(*TI)); } } diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 8f116c42d3d78..12229123675e7 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -231,7 +231,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Remove weight for this case. std::swap(Weights[Idx + 1], Weights.back()); Weights.pop_back(); - setBranchWeights(*SI, Weights, hasBranchWeightProvenance(MD)); + setBranchWeights(*SI, Weights, hasBranchWeightOrigin(MD)); } // Remove this entry. BasicBlock *ParentBB = SI->getParent();