Skip to content

Commit 92b481d

Browse files
committed
[llvm][IR] Extend BranchWeightMetadata to track provenance of weights
This patch implements the changes to LLVM IR discussed in https://discourse.llvm.org/t/rfc-update-branch-weights-metadata-to-allow-tracking-branch-weight-origins/75032 In this patch, we add an optional field to MD_prof meatdata nodes for branch weights, which can be used to distinguish weights added from `llvm.expect*` intrinsics from those added via other methods, e.g. from profiles or inserted by the compiler. One of the major motivations, is for use with MisExpect diagnostics, which need to know if branch_weight metadata originates from an llvm.expect intrinsic. Without that information, we end up checking branch weights multiple times in the case if ThinLTO + SampleProfiling, leading to some inaccuracy in how we report MisExpect related diagnostics to users. Since we change the format of MD_prof metadata in a fundamental way, we need to update code handling branch weights in a number of places. We also update the lang ref for branch weights to reflect the change. Pull Request: llvm#86609
1 parent 8ffd962 commit 92b481d

28 files changed

+175
-86
lines changed

clang/test/CodeGenCXX/attr-likelihood-if-vs-builtin-expect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,5 +221,5 @@ void tu2(int &i) {
221221
}
222222
}
223223

224-
// CHECK: [[BW_LIKELY]] = !{!"branch_weights", i32 2000, i32 1}
225-
// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", i32 1, i32 2000}
224+
// CHECK: [[BW_LIKELY]] = !{!"branch_weights", !"expected", i32 2000, i32 1}
225+
// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", !"expected", i32 1, i32 2000}

llvm/docs/BranchWeightMetadata.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@ Supported Instructions
2828

2929
Metadata is only assigned to the conditional branches. There are two extra
3030
operands for the true and the false branch.
31+
We optionally track if the metadata was added by ``__builtin_expect`` or
32+
``__builtin_expect_with_probability`` with an optional field ``!"expected"``.
3133

3234
.. code-block:: none
3335
3436
!0 = !{
3537
!"branch_weights",
38+
[ !"expected", ]
3639
i32 <TRUE_BRANCH_WEIGHT>,
3740
i32 <FALSE_BRANCH_WEIGHT>
3841
}
@@ -47,6 +50,7 @@ is always case #0).
4750
4851
!0 = !{
4952
!"branch_weights",
53+
[ !"expected", ]
5054
i32 <DEFAULT_BRANCH_WEIGHT>
5155
[ , i32 <CASE_BRANCH_WEIGHT> ... ]
5256
}
@@ -60,6 +64,7 @@ Branch weights are assigned to every destination.
6064
6165
!0 = !{
6266
!"branch_weights",
67+
[ !"expected", ]
6368
i32 <LABEL_BRANCH_WEIGHT>
6469
[ , i32 <LABEL_BRANCH_WEIGHT> ... ]
6570
}
@@ -75,6 +80,7 @@ block and entry counts which may not be accurate with sampling.
7580
7681
!0 = !{
7782
!"branch_weights",
83+
[ !"expected", ]
7884
i32 <CALL_BRANCH_WEIGHT>
7985
}
8086
@@ -95,6 +101,7 @@ is used.
95101
96102
!0 = !{
97103
!"branch_weights",
104+
[ !"expected", ]
98105
i32 <INVOKE_NORMAL_WEIGHT>
99106
[ , i32 <INVOKE_UNWIND_WEIGHT> ]
100107
}

llvm/include/llvm/IR/MDBuilder.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,17 @@ class MDBuilder {
5959
//===------------------------------------------------------------------===//
6060

6161
/// Return metadata containing two branch weights.
62-
MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight);
62+
/// @param TrueWeight the weight of the true branch
63+
/// @param FalseWeight the weight of the false branch
64+
/// @param Do these weights come from __builtin_expect*
65+
MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight,
66+
bool IsExpected = false);
6367

6468
/// Return metadata containing a number of branch weights.
65-
MDNode *createBranchWeights(ArrayRef<uint32_t> Weights);
69+
/// @param Weights the weights of all the branches
70+
/// @param Do these weights come from __builtin_expect*
71+
MDNode *createBranchWeights(ArrayRef<uint32_t> Weights,
72+
bool IsExpected = false);
6673

6774
/// Return metadata specifying that a branch or switch is unpredictable.
6875
MDNode *createUnpredictable();

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ MDNode *getBranchWeightMDNode(const Instruction &I);
5555
/// Nullptr otherwise.
5656
MDNode *getValidBranchWeightMDNode(const Instruction &I);
5757

58+
/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
59+
/// intrinsic
60+
bool hasExpectedProvenance(const Instruction &I);
61+
62+
/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
63+
/// intrinsic
64+
bool hasExpectedProvenance(const MDNode *ProfileData);
65+
66+
/// Return the offset to the first branch weight data
67+
unsigned getBranchWeightOffset(const Instruction &I);
68+
69+
/// Return the offset to the first branch weight data
70+
unsigned getBranchWeightOffset(const MDNode *ProfileData);
71+
5872
/// Extract branch weights from MD_prof metadata
5973
///
6074
/// \param ProfileData A pointer to an MDNode.
@@ -111,7 +125,11 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
111125

112126
/// Create a new `branch_weights` metadata node and add or overwrite
113127
/// a `prof` metadata reference to instruction `I`.
114-
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
128+
/// \param I the Instruction to set branch weights on.
129+
/// \param Weights an array of weights to set on instruction I.
130+
/// \param IsExpected were these weights added from an llvm.expect* intrinsic.
131+
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
132+
bool IsExpected);
115133

116134
} // namespace llvm
117135
#endif

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8843,7 +8843,8 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, ModifyDT &ModifiedDT) {
88438843
scaleWeights(NewTrueWeight, NewFalseWeight);
88448844
Br1->setMetadata(LLVMContext::MD_prof,
88458845
MDBuilder(Br1->getContext())
8846-
.createBranchWeights(TrueWeight, FalseWeight));
8846+
.createBranchWeights(TrueWeight, FalseWeight,
8847+
hasExpectedProvenance(*Br1)));
88478848

88488849
NewTrueWeight = TrueWeight;
88498850
NewFalseWeight = 2 * FalseWeight;

llvm/lib/IR/Instruction.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,12 +1210,22 @@ Instruction *Instruction::cloneImpl() const {
12101210

12111211
void Instruction::swapProfMetadata() {
12121212
MDNode *ProfileData = getBranchWeightMDNode(*this);
1213-
if (!ProfileData || ProfileData->getNumOperands() != 3)
1213+
if (!isBranchWeightMD(ProfileData))
12141214
return;
12151215

1216-
// The first operand is the name. Fetch them backwards and build a new one.
1217-
Metadata *Ops[] = {ProfileData->getOperand(0), ProfileData->getOperand(2),
1218-
ProfileData->getOperand(1)};
1216+
SmallVector<Metadata *, 4> Ops;
1217+
unsigned int FirstIdx = getBranchWeightOffset(ProfileData);
1218+
unsigned int SecondIdx = FirstIdx + 1;
1219+
// If there are more weights past the second, we can't swap them
1220+
if (ProfileData->getNumOperands() > SecondIdx + 1)
1221+
return;
1222+
Ops.push_back(ProfileData->getOperand(0));
1223+
if (hasExpectedProvenance(ProfileData)) {
1224+
Ops.push_back(ProfileData->getOperand(1));
1225+
}
1226+
// Switch the order of the weights
1227+
Ops.push_back(ProfileData->getOperand(SecondIdx));
1228+
Ops.push_back(ProfileData->getOperand(FirstIdx));
12191229
setMetadata(LLVMContext::MD_prof,
12201230
MDNode::get(ProfileData->getContext(), Ops));
12211231
}

llvm/lib/IR/Instructions.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -857,10 +857,12 @@ void CallInst::updateProfWeight(uint64_t S, uint64_t T) {
857857
APInt APS(128, S), APT(128, T);
858858
if (ProfDataName->getString().equals("branch_weights") &&
859859
ProfileData->getNumOperands() > 0) {
860+
unsigned int Offset = getBranchWeightOffset(ProfileData);
860861
// Using APInt::div may be expensive, but most cases should fit 64 bits.
861-
APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1))
862-
->getValue()
863-
.getZExtValue());
862+
APInt Val(128,
863+
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Offset))
864+
->getValue()
865+
.getZExtValue());
864866
Val *= APS;
865867
Vals.push_back(MDB.createConstant(
866868
ConstantInt::get(Type::getInt32Ty(getContext()),
@@ -5196,7 +5198,11 @@ void SwitchInstProfUpdateWrapper::init() {
51965198
if (!ProfileData)
51975199
return;
51985200

5199-
if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
5201+
// FIXME: This check belongs in ProfDataUtils. Its almost equivalent to
5202+
// getValidBranchWeightMDNode(), but the need to use llvm_unreachable
5203+
// makes them slightly different.
5204+
if (ProfileData->getNumOperands() !=
5205+
SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) {
52005206
llvm_unreachable("number of prof branch_weights metadata operands does "
52015207
"not correspond to number of succesors");
52025208
}

llvm/lib/IR/MDBuilder.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,23 @@ MDNode *MDBuilder::createFPMath(float Accuracy) {
3535
}
3636

3737
MDNode *MDBuilder::createBranchWeights(uint32_t TrueWeight,
38-
uint32_t FalseWeight) {
39-
return createBranchWeights({TrueWeight, FalseWeight});
38+
uint32_t FalseWeight, bool IsExpected) {
39+
return createBranchWeights({TrueWeight, FalseWeight}, IsExpected);
4040
}
4141

42-
MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights) {
42+
MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights,
43+
bool IsExpected) {
4344
assert(Weights.size() >= 1 && "Need at least one branch weights!");
4445

45-
SmallVector<Metadata *, 4> Vals(Weights.size() + 1);
46+
unsigned int Offset = IsExpected ? 2 : 1;
47+
SmallVector<Metadata *, 4> Vals(Weights.size() + Offset);
4648
Vals[0] = createString("branch_weights");
49+
if (IsExpected)
50+
Vals[1] = createString("expected");
4751

4852
Type *Int32Ty = Type::getInt32Ty(Context);
4953
for (unsigned i = 0, e = Weights.size(); i != e; ++i)
50-
Vals[i + 1] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
54+
Vals[i + Offset] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
5155

5256
return MDNode::get(Context, Vals);
5357
}

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ namespace {
4040
// We maintain some constants here to ensure that we access the branch weights
4141
// correctly, and can change the behavior in the future if the layout changes
4242

43-
// The index at which the weights vector starts
44-
constexpr unsigned WeightsIdx = 1;
45-
4643
// the minimum number of operands for MD_prof nodes with branch weights
4744
constexpr unsigned MinBWOps = 3;
4845

@@ -72,15 +69,16 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
7269
assert(isBranchWeightMD(ProfileData) && "wrong metadata");
7370

7471
unsigned NOps = ProfileData->getNumOperands();
72+
unsigned int WeightsIdx = getBranchWeightOffset(ProfileData);
7573
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
7674
Weights.resize(NOps - WeightsIdx);
7775

7876
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
7977
ConstantInt *Weight =
8078
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
8179
assert(Weight && "Malformed branch_weight in MD_prof node");
82-
assert(Weight->getValue().getActiveBits() <= 32 &&
83-
"Too many bits for uint32_t");
80+
assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
81+
"Too many bits for MD_prof branch_weight");
8482
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
8583
}
8684
}
@@ -106,6 +104,30 @@ bool hasValidBranchWeightMD(const Instruction &I) {
106104
return getValidBranchWeightMDNode(I);
107105
}
108106

107+
bool hasExpectedProvenance(const Instruction &I) {
108+
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
109+
return hasExpectedProvenance(ProfileData);
110+
}
111+
112+
bool hasExpectedProvenance(const MDNode *ProfileData) {
113+
if (!isBranchWeightMD(ProfileData))
114+
return false;
115+
116+
auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
117+
if (!ProfDataName)
118+
return false;
119+
return ProfDataName->getString().equals("expected");
120+
}
121+
122+
unsigned getBranchWeightOffset(const Instruction &I) {
123+
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
124+
return getBranchWeightOffset(ProfileData);
125+
}
126+
127+
unsigned getBranchWeightOffset(const MDNode *ProfileData) {
128+
return hasExpectedProvenance(ProfileData) ? 2 : 1;
129+
}
130+
109131
MDNode *getBranchWeightMDNode(const Instruction &I) {
110132
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
111133
if (!isBranchWeightMD(ProfileData))
@@ -115,7 +137,9 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
115137

116138
MDNode *getValidBranchWeightMDNode(const Instruction &I) {
117139
auto *ProfileData = getBranchWeightMDNode(I);
118-
if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
140+
auto Offset = getBranchWeightOffset(ProfileData);
141+
if (ProfileData &&
142+
ProfileData->getNumOperands() == Offset + I.getNumSuccessors())
119143
return ProfileData;
120144
return nullptr;
121145
}
@@ -174,7 +198,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
174198
return false;
175199

176200
if (ProfDataName->getString().equals("branch_weights")) {
177-
for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
201+
unsigned int Offset = getBranchWeightOffset(ProfileData);
202+
for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
178203
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
179204
assert(V && "Malformed branch_weight in MD_prof node");
180205
TotalVal += V->getValue().getZExtValue();
@@ -196,9 +221,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
196221
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
197222
}
198223

199-
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
224+
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
225+
bool IsExpected) {
200226
MDBuilder MDB(I.getContext());
201-
MDNode *BranchWeights = MDB.createBranchWeights(Weights);
227+
MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
202228
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
203229
}
204230

llvm/lib/IR/Verifier.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
#include "llvm/IR/Module.h"
104104
#include "llvm/IR/ModuleSlotTracker.h"
105105
#include "llvm/IR/PassManager.h"
106+
#include "llvm/IR/ProfDataUtils.h"
106107
#include "llvm/IR/Statepoint.h"
107108
#include "llvm/IR/Type.h"
108109
#include "llvm/IR/Use.h"
@@ -4756,8 +4757,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
47564757

47574758
// Check consistency of !prof branch_weights metadata.
47584759
if (ProfName.equals("branch_weights")) {
4760+
unsigned int Offset = getBranchWeightOffset(I);
47594761
if (isa<InvokeInst>(&I)) {
4760-
Check(MD->getNumOperands() == 2 || MD->getNumOperands() == 3,
4762+
Check(MD->getNumOperands() == (1 + Offset) ||
4763+
MD->getNumOperands() == (2 + Offset),
47614764
"Wrong number of InvokeInst branch_weights operands", MD);
47624765
} else {
47634766
unsigned ExpectedNumOperands = 0;
@@ -4777,10 +4780,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
47774780
CheckFailed("!prof branch_weights are not allowed for this instruction",
47784781
MD);
47794782

4780-
Check(MD->getNumOperands() == 1 + ExpectedNumOperands,
4783+
Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
47814784
"Wrong number of operands", MD);
47824785
}
4783-
for (unsigned i = 1; i < MD->getNumOperands(); ++i) {
4786+
for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
47844787
auto &MDO = MD->getOperand(i);
47854788
Check(MDO, "second operand should not be null", MD);
47864789
Check(mdconst::dyn_extract<ConstantInt>(MDO),

llvm/lib/Transforms/IPO/SampleProfile.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,7 +1775,8 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
17751775
else if (OverwriteExistingWeights)
17761776
I.setMetadata(LLVMContext::MD_prof, nullptr);
17771777
} else if (!isa<IntrinsicInst>(&I)) {
1778-
setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
1778+
setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])},
1779+
/*IsExpected=*/false);
17791780
}
17801781
}
17811782
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1786,7 +1787,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
17861787
if (cast<CallBase>(I).isIndirectCall()) {
17871788
I.setMetadata(LLVMContext::MD_prof, nullptr);
17881789
} else {
1789-
setBranchWeights(I, {uint32_t(0)});
1790+
setBranchWeights(I, {uint32_t(0)}, /*IsExpected=*/false);
17901791
}
17911792
}
17921793
}
@@ -1867,7 +1868,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
18671868
if (MaxWeight > 0 &&
18681869
(!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
18691870
LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
1870-
setBranchWeights(*TI, Weights);
1871+
setBranchWeights(*TI, Weights, /*IsExpected=*/false);
18711872
ORE->emit([&]() {
18721873
return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
18731874
<< "most popular destination for conditional branches at "

llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1878,7 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
18781878
static_cast<uint32_t>(CHRBranchBias.scale(1000)),
18791879
static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
18801880
};
1881-
setBranchWeights(*MergedBR, Weights);
1881+
setBranchWeights(*MergedBR, Weights, /*IsExpected=*/false);
18821882
CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
18831883
<< "\n");
18841884
}

llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
257257
promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
258258

259259
if (AttachProfToDirectCall) {
260-
setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
260+
setBranchWeights(NewInst, {static_cast<uint32_t>(Count)},
261+
/*IsExpected=*/false);
261262
}
262263

263264
using namespace ore;

llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,7 +1416,8 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
14161416
for (auto *Succ : successors(&BB))
14171417
Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
14181418
if (Weights.size() >= 2)
1419-
llvm::setBranchWeights(*BB.getTerminator(), Weights);
1419+
llvm::setBranchWeights(*BB.getTerminator(), Weights,
1420+
/*IsExpected=*/false);
14201421
}
14211422

14221423
unsigned NumCorruptCoverage = 0;
@@ -2191,7 +2192,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
21912192

21922193
misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
21932194

2194-
setBranchWeights(*TI, Weights);
2195+
setBranchWeights(*TI, Weights, /*IsExpected=*/false);
21952196
if (EmitBranchProbability) {
21962197
std::string BrCondStr = getBranchCondString(TI);
21972198
if (BrCondStr.empty())

0 commit comments

Comments
 (0)