Skip to content

[llvm][ir] Purge MD_prof custom accessors #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions llvm/include/llvm/Analysis/CFGPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/FormatVariadic.h"

Expand Down Expand Up @@ -276,14 +277,10 @@ struct DOTGraphTraits<DOTFuncInfo *> : public DefaultDOTGraphTraits {
if (Attrs.size())
return Attrs;

MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
MDNode *WeightsNode = getBranchWeightMDNode(*TI);
if (!WeightsNode)
return "";

MDString *MDName = cast<MDString>(WeightsNode->getOperand(0));
if (MDName->getString() != "branch_weights")
return "";

OpNo = I.getSuccessorIndex() + 1;
if (OpNo >= WeightsNode->getNumOperands())
return "";
Expand Down
2 changes: 0 additions & 2 deletions llvm/include/llvm/IR/Instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -3620,8 +3620,6 @@ class SwitchInstProfUpdateWrapper {
bool Changed = false;

protected:
static MDNode *getProfBranchWeightsMD(const SwitchInst &SI);

MDNode *buildProfBranchWeightsMD();

void init();
Expand Down
29 changes: 29 additions & 0 deletions llvm/include/llvm/IR/ProfDataUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,27 @@ bool isBranchWeightMD(const MDNode *ProfileData);
/// otherwise.
bool hasBranchWeightMD(const Instruction &I);

/// Checks if an instructions has valid Branch Weight Metadata
///
/// \param I The instruction to check
/// \returns True if I has an MD_prof node containing valid Branch Weights,
/// i.e., one weight for each successor. False otherwise.
bool hasValidBranchWeightMD(const Instruction &I);

/// Get the branch weights metadata node
///
/// \param I The Instruction to get the weights from.
/// \returns A pointer to I's branch weights metadata node, if it exists.
/// Nullptr otherwise.
MDNode *getBranchWeightMDNode(const Instruction &I);

/// Get the valid branch weights metadata node
///
/// \param I The Instruction to get the weights from.
/// \returns A pointer to I's valid branch weights metadata node, if it exists.
/// Nullptr otherwise.
MDNode *getValidBranchWeightMDNode(const Instruction &I);

/// Extract branch weights from MD_prof metadata
///
/// \param ProfileData A pointer to an MDNode.
Expand Down Expand Up @@ -70,5 +91,13 @@ bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
/// metadata was found.
bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);

/// Retrieve the total of all weights from an instruction.
///
/// \param I The instruction to extract the total weight from
/// \param [out] TotalWeights input variable to fill with total weights
/// \returns True on success with profile total weights filled in. False if no
/// metadata was found.
bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);

} // namespace llvm
#endif
9 changes: 2 additions & 7 deletions llvm/lib/Analysis/BranchProbabilityInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,18 +383,13 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
isa<InvokeInst>(TI) || isa<CallBrInst>(TI)))
return false;

MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
MDNode *WeightsNode = getValidBranchWeightMDNode(*TI);
if (!WeightsNode)
return false;

// Check that the number of successors is manageable.
assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");

// Ensure there are weights for all of the successors. Note that the first
// operand to the metadata node is a name, not a weight.
if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
return false;

// Build up the final weights that will be used in a temporary buffer.
// Compute the sum of all weights to later decide whether they need to
// be scaled to fit in 32 bits.
Expand All @@ -403,7 +398,7 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
SmallVector<unsigned, 2> UnreachableIdxs;
SmallVector<unsigned, 2> ReachableIdxs;

extractBranchWeights(*TI, Weights);
extractBranchWeights(WeightsNode, Weights);
for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
WeightSum += Weights[I];
const LoopBlock SrcLoopBB = getLoopBlock(BB);
Expand Down
10 changes: 3 additions & 7 deletions llvm/lib/IR/Instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Type.h"
using namespace llvm;

Expand Down Expand Up @@ -855,13 +856,8 @@ Instruction *Instruction::cloneImpl() const {
}

void Instruction::swapProfMetadata() {
MDNode *ProfileData = getMetadata(LLVMContext::MD_prof);
if (!ProfileData || ProfileData->getNumOperands() != 3 ||
!isa<MDString>(ProfileData->getOperand(0)))
return;

MDString *MDName = cast<MDString>(ProfileData->getOperand(0));
if (MDName->getString() != "branch_weights")
MDNode *ProfileData = getBranchWeightMDNode(*this);
if (!ProfileData || ProfileData->getNumOperands() != 3)
return;

// The first operand is the name. Fetch them backwards and build a new one.
Expand Down
21 changes: 5 additions & 16 deletions llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/AtomicOrdering.h"
Expand Down Expand Up @@ -4572,15 +4573,6 @@ void SwitchInst::growOperands() {
growHungoffUses(ReservedSpace);
}

MDNode *
SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) {
if (MDNode *ProfileData = SI.getMetadata(LLVMContext::MD_prof))
if (auto *MDName = dyn_cast<MDString>(ProfileData->getOperand(0)))
if (MDName->getString() == "branch_weights")
return ProfileData;
return nullptr;
}

MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
assert(Changed && "called only if metadata has changed");

Expand All @@ -4599,7 +4591,7 @@ MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
}

void SwitchInstProfUpdateWrapper::init() {
MDNode *ProfileData = getProfBranchWeightsMD(SI);
MDNode *ProfileData = getBranchWeightMDNode(SI);
if (!ProfileData)
return;

Expand All @@ -4609,11 +4601,8 @@ void SwitchInstProfUpdateWrapper::init() {
}

SmallVector<uint32_t, 8> Weights;
for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
ConstantInt *C = mdconst::extract<ConstantInt>(ProfileData->getOperand(CI));
uint32_t CW = C->getValue().getZExtValue();
Weights.push_back(CW);
}
if (!extractBranchWeights(ProfileData, Weights))
return;
this->Weights = std::move(Weights);
}

Expand Down Expand Up @@ -4686,7 +4675,7 @@ void SwitchInstProfUpdateWrapper::setSuccessorWeight(
SwitchInstProfUpdateWrapper::CaseWeightOpt
SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI,
unsigned idx) {
if (MDNode *ProfileData = getProfBranchWeightsMD(SI))
if (MDNode *ProfileData = getBranchWeightMDNode(SI))
if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1)
return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
->getValue()
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/IR/Metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1544,7 +1544,7 @@ bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
getOpcode() == Instruction::Switch) &&
"Looking for branch weights on something besides branch");

return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal);
return ::extractProfTotalWeight(*this, TotalVal);
}

void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) {
Expand Down
29 changes: 28 additions & 1 deletion llvm/lib/IR/ProfDataUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ bool hasBranchWeightMD(const Instruction &I) {
return isBranchWeightMD(ProfileData);
}

bool hasValidBranchWeightMD(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
return false;
if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
return true;
return false;
}

MDNode *getBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
return nullptr;
return ProfileData;
}

MDNode *getValidBranchWeightMDNode(const Instruction &I) {
if (!hasValidBranchWeightMD(I))
return nullptr;
return I.getMetadata(LLVMContext::MD_prof);
}

bool extractBranchWeights(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights) {
if (!isBranchWeightMD(ProfileData))
Expand All @@ -118,7 +140,8 @@ bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
uint64_t &FalseVal) {
assert((I.getOpcode() == Instruction::Br ||
I.getOpcode() == Instruction::Select) &&
"Looking for branch weights on something besides branch or select");
"Looking for branch weights on something besides branch, select, or "
"switch");

SmallVector<uint32_t, 2> Weights;
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
Expand Down Expand Up @@ -161,4 +184,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
return false;
}

bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}

} // namespace llvm
3 changes: 1 addition & 2 deletions llvm/lib/Transforms/IPO/PartialInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,7 @@ static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) {
BranchInst *BR = dyn_cast<BranchInst>(E->getTerminator());
if (!BR || BR->isUnconditional())
continue;
uint64_t T, F;
if (extractBranchWeights(*BR, T, F))
if (hasBranchWeightMD(*BR))
return true;
}
return false;
Expand Down
39 changes: 16 additions & 23 deletions llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -575,32 +576,26 @@ checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT,
return true;
}

// Returns true and sets the true probability and false probability of an
// MD_prof metadata if it's well-formed.
static bool checkMDProf(MDNode *MD, BranchProbability &TrueProb,
BranchProbability &FalseProb) {
if (!MD) return false;
MDString *MDName = cast<MDString>(MD->getOperand(0));
if (MDName->getString() != "branch_weights" ||
MD->getNumOperands() != 3)
// Constructs the true and false branch probabilities if the the instruction has
// valid branch weights. Returns true when this was successful, false otherwise.
static bool extractBranchProbabilities(Instruction *I,
BranchProbability &TrueProb,
BranchProbability &FalseProb) {
uint64_t TrueWeight;
uint64_t FalseWeight;
if (!extractBranchWeights(*I, TrueWeight, FalseWeight))
return false;
ConstantInt *TrueWeight = mdconst::extract<ConstantInt>(MD->getOperand(1));
ConstantInt *FalseWeight = mdconst::extract<ConstantInt>(MD->getOperand(2));
if (!TrueWeight || !FalseWeight)
return false;
uint64_t TrueWt = TrueWeight->getValue().getZExtValue();
uint64_t FalseWt = FalseWeight->getValue().getZExtValue();
uint64_t SumWt = TrueWt + FalseWt;
uint64_t SumWeight = TrueWeight + FalseWeight;

assert(SumWt >= TrueWt && SumWt >= FalseWt &&
assert(SumWeight >= TrueWeight && SumWeight >= FalseWeight &&
"Overflow calculating branch probabilities.");

// Guard against 0-to-0 branch weights to avoid a division-by-zero crash.
if (SumWt == 0)
if (SumWeight == 0)
return false;

TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt);
FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt);
TrueProb = BranchProbability::getBranchProbability(TrueWeight, SumWeight);
FalseProb = BranchProbability::getBranchProbability(FalseWeight, SumWeight);
return true;
}

Expand Down Expand Up @@ -639,8 +634,7 @@ static bool checkBiasedBranch(BranchInst *BI, Region *R,
if (!BI->isConditional())
return false;
BranchProbability ThenProb, ElseProb;
if (!checkMDProf(BI->getMetadata(LLVMContext::MD_prof),
ThenProb, ElseProb))
if (!extractBranchProbabilities(BI, ThenProb, ElseProb))
return false;
BasicBlock *IfThen = BI->getSuccessor(0);
BasicBlock *IfElse = BI->getSuccessor(1);
Expand Down Expand Up @@ -669,8 +663,7 @@ static bool checkBiasedSelect(
DenseSet<SelectInst *> &FalseBiasedSelectsGlobal,
DenseMap<SelectInst *, BranchProbability> &SelectBiasMap) {
BranchProbability TrueProb, FalseProb;
if (!checkMDProf(SI->getMetadata(LLVMContext::MD_prof),
TrueProb, FalseProb))
if (!extractBranchProbabilities(SI, TrueProb, FalseProb))
return false;
CHR_DEBUG(dbgs() << "SI " << *SI << " ");
CHR_DEBUG(dbgs() << "TrueProb " << TrueProb << " ");
Expand Down
13 changes: 1 addition & 12 deletions llvm/lib/Transforms/Scalar/JumpThreading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2522,18 +2522,7 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB,
bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) {
const Instruction *TI = BB->getTerminator();
assert(TI->getNumSuccessors() > 1 && "not a split");

MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
if (!WeightsNode)
return false;

MDString *MDName = cast<MDString>(WeightsNode->getOperand(0));
if (MDName->getString() != "branch_weights")
return false;

// Ensure there are weights for all of the successors. Note that the first
// operand to the metadata node is a name, not a weight.
return WeightsNode->getNumOperands() == TI->getNumSuccessors() + 1;
return hasValidBranchWeightMD(*TI);
}

/// Update the block frequency of BB and branch weight and the metadata on the
Expand Down
30 changes: 9 additions & 21 deletions llvm/lib/Transforms/Scalar/LoopPredication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -974,37 +975,24 @@ bool LoopPredication::isLoopProfitableToPredicate() {
LatchExitBlock->getTerminatingDeoptimizeCall())
return false;

auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) {
if (!ProfileData || !ProfileData->getOperand(0))
return false;
if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0)))
if (!MDS->getString().equals("branch_weights"))
return false;
if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors())
return false;
return true;
};
MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof);
// Latch terminator has no valid profile data, so nothing to check
// profitability on.
if (!IsValidProfileData(LatchProfileData, LatchTerm))
if (!hasValidBranchWeightMD(*LatchTerm))
return true;

auto ComputeBranchProbability =
[&](const BasicBlock *ExitingBlock,
const BasicBlock *ExitBlock) -> BranchProbability {
auto *Term = ExitingBlock->getTerminator();
MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof);
unsigned NumSucc = Term->getNumSuccessors();
if (IsValidProfileData(ProfileData, Term)) {
uint64_t Numerator = 0, Denominator = 0, ProfVal = 0;
for (unsigned i = 0; i < NumSucc; i++) {
ConstantInt *CI =
mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1));
ProfVal = CI->getValue().getZExtValue();
if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
SmallVector<uint32_t> Weights;
extractBranchWeights(ProfileData, Weights);
uint64_t Numerator = 0, Denominator = 0;
for (auto [i, Weight] : llvm::enumerate(Weights)) {
if (Term->getSuccessor(i) == ExitBlock)
Numerator += ProfVal;
Denominator += ProfVal;
Numerator += Weight;
Denominator += Weight;
}
return BranchProbability::getBranchProbability(Numerator, Denominator);
} else {
Expand Down
Loading