110
110
#include " llvm/Transforms/Instrumentation.h"
111
111
#include " llvm/Transforms/Instrumentation/BlockCoverageInference.h"
112
112
#include " llvm/Transforms/Instrumentation/CFGMST.h"
113
+ #include " llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
113
114
#include " llvm/Transforms/Utils/BasicBlockUtils.h"
114
115
#include " llvm/Transforms/Utils/MisExpect.h"
115
116
#include " llvm/Transforms/Utils/ModuleUtils.h"
@@ -333,6 +334,20 @@ extern cl::opt<bool> EnableVTableValueProfiling;
333
334
extern cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate;
334
335
} // namespace llvm
335
336
337
// Returns true when the function entry basic block should be instrumented:
// either the PGOInstrumentEntry option is set, or
// PGOCtxProfLoweringPass::isContextualIRPGOEnabled() reports that contextual
// IR PGO is enabled. The result is used in this file to set
// VARIANT_MASK_INSTR_ENTRY in the profile version and as the entry
// instrumentation argument when constructing FuncPGOInstrumentation.
+ bool shouldInstrumentEntryBB () {
338
+ return PGOInstrumentEntry ||
339
+ PGOCtxProfLoweringPass::isContextualIRPGOEnabled ();
340
+ }
341
+
342
+ // FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls. Ctx
343
+ // profiling implicitly captures indirect call cases, but not other values.
344
+ // Supporting other values is relatively straight-forward - just another counter
345
+ // range within the context.
346
// Returns true when value profiling must be skipped: either the
// DisableValueProfiling option is set, or
// PGOCtxProfLoweringPass::isContextualIRPGOEnabled() reports that contextual
// IR PGO is enabled. Both the instrumentation path (instrumentOneFunc) and
// the use path (PGOUseFunc::annotateValueSites) return early when this is
// true.
+ bool isValueProfilingDisabled () {
347
+ return DisableValueProfiling ||
348
+ PGOCtxProfLoweringPass::isContextualIRPGOEnabled ();
349
+ }
350
+
336
351
// Return a string describing the branch condition that can be
337
352
// used in static branch probability heuristics:
338
353
static std::string getBranchCondString (Instruction *TI) {
@@ -379,7 +394,7 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
379
394
uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
380
395
if (IsCS)
381
396
ProfileVersion |= VARIANT_MASK_CSIR_PROF;
382
- if (PGOInstrumentEntry )
397
+ if (shouldInstrumentEntryBB () )
383
398
ProfileVersion |= VARIANT_MASK_INSTR_ENTRY;
384
399
if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)
385
400
ProfileVersion |= VARIANT_MASK_DBG_CORRELATE;
@@ -861,7 +876,7 @@ static void instrumentOneFunc(
861
876
}
862
877
863
878
FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo (
864
- F, TLI, ComdatMembers, true , BPI, BFI, IsCS, PGOInstrumentEntry ,
879
+ F, TLI, ComdatMembers, true , BPI, BFI, IsCS, shouldInstrumentEntryBB () ,
865
880
PGOBlockCoverage);
866
881
867
882
auto Name = FuncInfo.FuncNameVar ;
@@ -883,6 +898,43 @@ static void instrumentOneFunc(
883
898
unsigned NumCounters =
884
899
InstrumentBBs.size () + FuncInfo.SIVisitor .getNumOfSelectInsts ();
885
900
901
+ if (PGOCtxProfLoweringPass::isContextualIRPGOEnabled ()) {
902
+ auto *CSIntrinsic =
903
+ Intrinsic::getDeclaration (M, Intrinsic::instrprof_callsite);
904
+ // We want to count the instrumentable callsites, then instrument them. This
905
+ // is because the llvm.instrprof.callsite intrinsic has an argument (like
906
+ // the other instrprof intrinsics) capturing the total number of
907
+ // instrumented objects (counters, or callsites, in this case). In this
908
+ // case, we want that value so we can readily pass it to the compiler-rt
909
+ // APIs that may have to allocate memory based on the number of callsites.
910
+ // The traversal logic is the same for both counting and instrumentation,
911
+ // just needs to be done in succession.
912
+ auto Visit = [&](llvm::function_ref<void (CallBase * CB)> Visitor) {
913
+ for (auto &BB : F)
914
+ for (auto &Instr : BB)
915
+ if (auto *CS = dyn_cast<CallBase>(&Instr)) {
916
+ if ((CS->getCalledFunction () &&
917
+ CS->getCalledFunction ()->isIntrinsic ()) ||
918
+ dyn_cast<InlineAsm>(CS->getCalledOperand ()))
919
+ continue ;
920
+ Visitor (CS);
921
+ }
922
+ };
923
+ // First, count callsites.
924
+ uint32_t TotalNrCallsites = 0 ;
925
+ Visit ([&TotalNrCallsites](auto *) { ++TotalNrCallsites; });
926
+
927
+ // Now instrument.
928
+ uint32_t CallsiteIndex = 0 ;
929
+ Visit ([&](auto *CB) {
930
+ IRBuilder<> Builder (CB);
931
+ Builder.CreateCall (CSIntrinsic,
932
+ {Name, CFGHash, Builder.getInt32 (TotalNrCallsites),
933
+ Builder.getInt32 (CallsiteIndex++),
934
+ CB->getCalledOperand ()});
935
+ });
936
+ }
937
+
886
938
uint32_t I = 0 ;
887
939
if (PGOTemporalInstrumentation) {
888
940
NumCounters += PGOBlockCoverage ? 8 : 1 ;
@@ -914,7 +966,7 @@ static void instrumentOneFunc(
914
966
FuncInfo.FunctionHash );
915
967
assert (I == NumCounters);
916
968
917
- if (DisableValueProfiling )
969
+ if (isValueProfilingDisabled () )
918
970
return ;
919
971
920
972
NumOfPGOICall += FuncInfo.ValueSites [IPVK_IndirectCallTarget].size ();
@@ -1676,7 +1728,7 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) {
1676
1728
1677
1729
// Traverse all valuesites and annotate the instructions for all value kind.
1678
1730
void PGOUseFunc::annotateValueSites () {
1679
- if (DisableValueProfiling )
1731
+ if (isValueProfilingDisabled () )
1680
1732
return ;
1681
1733
1682
1734
// Create the PGOFuncName meta data.
@@ -1779,7 +1831,7 @@ static bool InstrumentAllFunctions(
1779
1831
function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
1780
1832
// For the context-sensitive instrumentation, we should have a separate pass
1781
1833
// (before LTO/ThinLTO linking) to create these variables.
1782
- if (!IsCS)
1834
+ if (!IsCS && ! PGOCtxProfLoweringPass::isContextualIRPGOEnabled () )
1783
1835
createIRLevelProfileFlagVar (M, /* IsCS=*/ false );
1784
1836
1785
1837
Triple TT (M.getTargetTriple ());
@@ -2018,6 +2070,8 @@ static bool annotateAllFunctions(
2018
2070
bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled ();
2019
2071
if (PGOInstrumentEntry.getNumOccurrences () > 0 )
2020
2072
InstrumentFuncEntry = PGOInstrumentEntry;
2073
+ InstrumentFuncEntry |= PGOCtxProfLoweringPass::isContextualIRPGOEnabled ();
2074
+
2021
2075
bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage ();
2022
2076
for (auto &F : M) {
2023
2077
if (skipPGOUse (F))
0 commit comments