-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[llvm][profdata][NFC] Support 64-bit weights in ProfDataUtils #86607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[llvm][profdata][NFC] Support 64-bit weights in ProfDataUtils #86607
Conversation
Created using spr 1.3.4
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-llvm-transforms Author: Paul Kirth (ilovepi) ChangesSince some places, like SimplifyCFG work with 64-bit weights, we supply an API We change the API slightly to disambiguate the 64 bit version from the 32 bit Full diff: https://github.com/llvm/llvm-project/pull/86607.diff 4 Files Affected:
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 255fa2ff1c7906..dc983eed13a8d3 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -65,10 +65,15 @@ bool extractBranchWeights(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights);
/// Faster version of extractBranchWeights() that skips checks and must only
-/// be called with "branch_weights" metadata nodes.
-void extractFromBranchWeightMD(const MDNode *ProfileData,
+/// be called with "branch_weights" metadata nodes. Supports uint32_t.
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights);
+/// Faster version of extractBranchWeights() that skips checks and must only
+/// be called with "branch_weights" metadata nodes. Supports uint64_t.
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+ SmallVectorImpl<uint64_t> &Weights);
+
/// Extract branch weights attatched to an Instruction
///
/// \param I The Instruction to extract weights from.
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1a10d0ce5a522..b4e09e76993f99 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
return ProfDataName->getString().equals(Name);
}
+template <typename T,
+ typename = typename std::enable_if<std::is_arithmetic_v<T>>>
+static void extractFromBranchWeightMD(const MDNode *ProfileData,
+ SmallVectorImpl<T> &Weights) {
+ assert(isBranchWeightMD(ProfileData) && "wrong metadata");
+
+ unsigned NOps = ProfileData->getNumOperands();
+ assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+ Weights.resize(NOps - WeightsIdx);
+
+ for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
+ ConstantInt *Weight =
+ mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+ assert(Weight && "Malformed branch_weight in MD_prof node");
+ assert(Weight->getValue().getActiveBits() <= 32 &&
+ "Too many bits for uint32_t");
+ Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+ }
+}
+
} // namespace
namespace llvm {
@@ -100,24 +120,21 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
return nullptr;
}
-void extractFromBranchWeightMD(const MDNode *ProfileData,
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights) {
- assert(isBranchWeightMD(ProfileData) && "wrong metadata");
-
- unsigned NOps = ProfileData->getNumOperands();
- assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
- Weights.resize(NOps - WeightsIdx);
+ extractFromBranchWeightMD(ProfileData, Weights);
+}
- for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
- ConstantInt *Weight =
- mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
- assert(Weight && "Malformed branch_weight in MD_prof node");
- assert(Weight->getValue().getActiveBits() <= 32 &&
- "Too many bits for uint32_t");
- Weights[Idx - WeightsIdx] = Weight->getZExtValue();
- }
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+ SmallVectorImpl<uint64_t> &Weights) {
+ extractFromBranchWeightMD(ProfileData, Weights);
}
+
+
+
+
+
bool extractBranchWeights(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights) {
if (!isBranchWeightMD(ProfileData))
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index bc671171137199..f4b43ce370a5da 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
return;
SmallVector<uint32_t, 2> Weights;
- extractFromBranchWeightMD(WeightMD, Weights);
+ extractFromBranchWeightMD32(WeightMD, Weights);
if (Weights.size() != 2)
return;
uint32_t OrigLoopExitWeight = Weights[0];
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 55bbffb18879fb..a425e26d490e4f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1065,11 +1065,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
static void GetBranchWeights(Instruction *TI,
SmallVectorImpl<uint64_t> &Weights) {
MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
- assert(MD);
- for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
- ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
- Weights.push_back(CI->getValue().getZExtValue());
- }
+ assert(MD && "Invalid branch-weight metadata");
+ extractFromBranchWeightMD64(MD, Weights);
// If TI is a conditional eq, the default case is the false case,
// and the corresponding branch-weight data is at index 2. We swap the
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Created using spr 1.3.4
Since some places, like SimplifyCFG work with 64-bit weights, we supply an API in ProfDataUtils to extract the weights accordingly. We change the API slightly to disambiguate the 64 bit version from the 32 bit version. Pull Request: llvm#86607
can you explain when 32-bit vs 64-bit weights are used? |
Be careful here! I think there is a bunch of code that is summing up weights and stores intermediate results in a
|
I am assuming that this also serves as preparation to increasing the bit size of the weight annotations? |
I think we generally work w/ 32-bit weights, but in some cases, we use 64-bit to make sure that summing or scaling don't overflow. @MatzeB mentioned some of this in his comment. My motivation here is to avoid having bespoke handling of branch weight extraction. To a large extent due to to #86609.
Right, I don't want to change the defaults, I just want to provide better utilities, so people aren't manually walking the MD_prof metadata. When we add the proposed optional field in #86609, the offsets won't be fixed, and IMO, its better to provide the necessary utilities and point people towards those. That's one of the reasons I left
That isn't one of my goals. |
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah ok, if this is just for convenience / consistency then fine with me. LGTM, thanks
Since some places, like SimplifyCFG work with 64-bit weights, we supply an API
in ProfDataUtils to extract the weights accordingly.
We change the API slightly to disambiguate the 64 bit version from the 32 bit
version.