Skip to content

Commit 8aaef96

Browse files
committed
[llvm][ctx_profile] Add instrumentation lowering
This adds the instrumentation lowering pass. (Tracking Issue: llvm#89287, RFC referenced there)
1 parent e98cb36 commit 8aaef96

File tree

6 files changed

+473
-1
lines changed

6 files changed

+473
-1
lines changed

llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212
#ifndef LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFLOWERING_H
1313
#define LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFLOWERING_H
1414

15+
#include "llvm/IR/PassManager.h"
1516
namespace llvm {
1617
class Type;
1718

18-
class PGOCtxProfLoweringPass {
19+
class PGOCtxProfLoweringPass : public PassInfoMixin<PGOCtxProfLoweringPass> {
1920
public:
2021
explicit PGOCtxProfLoweringPass() = default;
2122
static bool isContextualIRPGOEnabled();
23+
24+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
2225
};
2326
} // namespace llvm
2427
#endif

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
#include "llvm/Transforms/Instrumentation/LowerAllowCheckPass.h"
176176
#include "llvm/Transforms/Instrumentation/MemProfiler.h"
177177
#include "llvm/Transforms/Instrumentation/MemorySanitizer.h"
178+
#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
178179
#include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
179180
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
180181
#include "llvm/Transforms/Instrumentation/PoisonChecking.h"

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
#include "llvm/Transforms/Instrumentation/InstrOrderFile.h"
7575
#include "llvm/Transforms/Instrumentation/InstrProfiling.h"
7676
#include "llvm/Transforms/Instrumentation/MemProfiler.h"
77+
#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
7778
#include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
7879
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
7980
#include "llvm/Transforms/Scalar/ADCE.h"
@@ -834,6 +835,10 @@ void PassBuilder::addPGOInstrPasses(ModulePassManager &MPM,
834835
PTO.EagerlyInvalidateAnalyses));
835836
}
836837

838+
if (PGOCtxProfLoweringPass::isContextualIRPGOEnabled()) {
839+
MPM.addPass(PGOCtxProfLoweringPass());
840+
return;
841+
}
837842
// Add the profile lowering pass.
838843
InstrProfOptions Options;
839844
if (!ProfileFile.empty())

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ MODULE_PASS("inliner-wrapper-no-mandatory-first",
7777
MODULE_PASS("insert-gcov-profiling", GCOVProfilerPass())
7878
MODULE_PASS("instrorderfile", InstrOrderFilePass())
7979
MODULE_PASS("instrprof", InstrProfilingLoweringPass())
80+
MODULE_PASS("pgo-ctx-instr-lower", PGOCtxProfLoweringPass())
8081
MODULE_PASS("internalize", InternalizePass())
8182
MODULE_PASS("invalidate<all>", InvalidateAllAnalysesPass())
8283
MODULE_PASS("iroutliner", IROutlinerPass())

llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,19 @@
88
//
99

1010
#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
11+
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
12+
#include "llvm/IR/DiagnosticInfo.h"
13+
#include "llvm/IR/IRBuilder.h"
14+
#include "llvm/IR/Instructions.h"
15+
#include "llvm/IR/IntrinsicInst.h"
16+
#include "llvm/IR/PassManager.h"
1117
#include "llvm/Support/CommandLine.h"
18+
#include <utility>
1219

1320
using namespace llvm;
1421

22+
#define DEBUG_TYPE "ctx-profile-lower"
23+
1524
static cl::list<std::string> ContextRoots(
1625
"profile-context-root", cl::Hidden,
1726
cl::desc(
@@ -22,3 +31,295 @@ static cl::list<std::string> ContextRoots(
2231
bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
2332
return !ContextRoots.empty();
2433
}
34+
35+
// the names of symbols we expect in compiler-rt. Using a namespace for
36+
// readability.
37+
namespace CompilerRtAPINames {
38+
static auto StartCtx = "__llvm_ctx_profile_start_context";
39+
static auto ReleaseCtx = "__llvm_ctx_profile_release_context";
40+
static auto GetCtx = "__llvm_ctx_profile_get_context";
41+
static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
42+
static auto CallsiteTLS = "__llvm_ctx_profile_callsite";
43+
} // namespace CompilerRtAPINames
44+
45+
namespace {
46+
// The lowering logic and state.
47+
class CtxInstrumentationLowerer final {
48+
Module &M;
49+
ModuleAnalysisManager &MAM;
50+
Type *ContextNodeTy = nullptr;
51+
Type *ContextRootTy = nullptr;
52+
53+
DenseMap<const Function *, Constant *> ContextRootMap;
54+
Function *StartCtx = nullptr;
55+
Function *GetCtx = nullptr;
56+
Function *ReleaseCtx = nullptr;
57+
GlobalVariable *ExpectedCalleeTLS = nullptr;
58+
GlobalVariable *CallsiteInfoTLS = nullptr;
59+
60+
public:
61+
CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
62+
void lowerFunction(Function &F);
63+
};
64+
65+
std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
66+
uint32_t NrCounters = 0;
67+
uint32_t NrCallsites = 0;
68+
for (const auto &BB : F) {
69+
for (const auto &I : BB) {
70+
if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
71+
if (!NrCounters)
72+
NrCounters =
73+
static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
74+
} else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
75+
if (!NrCallsites)
76+
NrCallsites =
77+
static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
78+
}
79+
if (NrCounters && NrCallsites)
80+
return std::make_pair(NrCounters, NrCallsites);
81+
}
82+
}
83+
return {0, 0};
84+
}
85+
} // namespace
86+
87+
// set up tie-in with compiler-rt.
88+
// NOTE!!!
89+
// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
90+
CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
91+
ModuleAnalysisManager &MAM)
92+
: M(M), MAM(MAM) {
93+
auto *PointerTy = PointerType::get(M.getContext(), 0);
94+
auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
95+
auto *I32Ty = Type::getInt32Ty(M.getContext());
96+
auto *I64Ty = Type::getInt64Ty(M.getContext());
97+
98+
// The ContextRoot type
99+
ContextRootTy =
100+
StructType::get(M.getContext(), {
101+
PointerTy, /*FirstNode*/
102+
PointerTy, /*FirstMemBlock*/
103+
PointerTy, /*CurrentMem*/
104+
SanitizerMutexType, /*Taken*/
105+
});
106+
// The Context header.
107+
ContextNodeTy = StructType::get(M.getContext(), {
108+
I64Ty, /*Guid*/
109+
PointerTy, /*Next*/
110+
I32Ty, /*NrCounters*/
111+
I32Ty, /*NrCallsites*/
112+
});
113+
114+
// Define a global for each entrypoint. We'll reuse the entrypoint's name as
115+
// prefix. We assume the entrypoint names to be unique.
116+
for (const auto &Fname : ContextRoots) {
117+
if (const auto *F = M.getFunction(Fname)) {
118+
if (F->isDeclaration())
119+
continue;
120+
auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
121+
cast<GlobalVariable>(G)->setInitializer(
122+
Constant::getNullValue(ContextRootTy));
123+
ContextRootMap.insert(std::make_pair(F, G));
124+
}
125+
}
126+
127+
// Declare the functions we will call.
128+
StartCtx = cast<Function>(
129+
M.getOrInsertFunction(
130+
CompilerRtAPINames::StartCtx,
131+
FunctionType::get(ContextNodeTy->getPointerTo(),
132+
{ContextRootTy->getPointerTo(), /*ContextRoot*/
133+
I64Ty, /*Guid*/ I32Ty,
134+
/*NrCounters*/ I32Ty /*NrCallsites*/},
135+
false))
136+
.getCallee());
137+
GetCtx = cast<Function>(
138+
M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
139+
FunctionType::get(ContextNodeTy->getPointerTo(),
140+
{PointerTy, /*Callee*/
141+
I64Ty, /*Guid*/
142+
I32Ty, /*NrCounters*/
143+
I32Ty}, /*NrCallsites*/
144+
false))
145+
.getCallee());
146+
ReleaseCtx = cast<Function>(
147+
M.getOrInsertFunction(
148+
CompilerRtAPINames::ReleaseCtx,
149+
FunctionType::get(Type::getVoidTy(M.getContext()),
150+
{
151+
ContextRootTy->getPointerTo(), /*ContextRoot*/
152+
},
153+
false))
154+
.getCallee());
155+
156+
// Declare the TLSes we will need to use.
157+
CallsiteInfoTLS =
158+
new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
159+
nullptr, CompilerRtAPINames::CallsiteTLS);
160+
CallsiteInfoTLS->setThreadLocal(true);
161+
CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
162+
ExpectedCalleeTLS =
163+
new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
164+
nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
165+
ExpectedCalleeTLS->setThreadLocal(true);
166+
ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
167+
}
168+
169+
PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
170+
ModuleAnalysisManager &MAM) {
171+
CtxInstrumentationLowerer Lowerer(M, MAM);
172+
for (auto &F : M)
173+
Lowerer.lowerFunction(F);
174+
return PreservedAnalyses::none();
175+
}
176+
177+
void CtxInstrumentationLowerer::lowerFunction(Function &F) {
178+
if (F.isDeclaration())
179+
return;
180+
auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
181+
auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
182+
183+
Value *Guid = nullptr;
184+
auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);
185+
186+
Value *Context = nullptr;
187+
Value *RealContext = nullptr;
188+
189+
StructType *ThisContextType = nullptr;
190+
Value *TheRootContext = nullptr;
191+
Value *ExpectedCalleeTLSAddr = nullptr;
192+
Value *CallsiteInfoTLSAddr = nullptr;
193+
194+
auto &Head = F.getEntryBlock();
195+
for (auto &I : Head) {
196+
// Find the increment intrinsic in the entry basic block.
197+
if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
198+
assert(Mark->getIndex()->isZero());
199+
200+
IRBuilder<> Builder(Mark);
201+
// FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
202+
Guid = Builder.getInt64(F.getGUID());
203+
// The type of the context of this function is now knowable since we have
204+
// NrCallsites and NrCounters. We delcare it here because it's more
205+
// convenient - we have the Builder.
206+
ThisContextType = StructType::get(
207+
F.getContext(),
208+
{ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
209+
ArrayType::get(Builder.getPtrTy(), NrCallsites)});
210+
// Figure out which way we obtain the context object for this function -
211+
// if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
212+
// former case, we also set TheRootContext since we need it to release it
213+
// at the end (plus it can be used to know if we have an entrypoint or a
214+
// regular function)
215+
auto Iter = ContextRootMap.find(&F);
216+
if (Iter != ContextRootMap.end()) {
217+
TheRootContext = Iter->second;
218+
Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid,
219+
Builder.getInt32(NrCounters),
220+
Builder.getInt32(NrCallsites)});
221+
ORE.emit(
222+
[&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
223+
} else {
224+
Context =
225+
Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
226+
Builder.getInt32(NrCallsites)});
227+
ORE.emit([&] {
228+
return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
229+
});
230+
}
231+
// The context could be scratch.
232+
auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
233+
if (NrCallsites > 0) {
234+
// Figure out which index of the TLS 2-element buffers to use.
235+
// Scratch context => we use index == 1. Real contexts => index == 0.
236+
auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
237+
// The GEPs corresponding to that index, in the respective TLS.
238+
ExpectedCalleeTLSAddr = Builder.CreateGEP(
239+
Builder.getInt8Ty()->getPointerTo(),
240+
Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
241+
CallsiteInfoTLSAddr = Builder.CreateGEP(
242+
Builder.getInt32Ty(),
243+
Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
244+
}
245+
// Because the context pointer may have LSB set (to indicate scratch),
246+
// clear it for the value we use as base address for the counter vector.
247+
// This way, if later we want to have "real" (not clobbered) buffers
248+
// acting as scratch, the lowering (at least this part of it that deals
249+
// with counters) stays the same.
250+
RealContext = Builder.CreateIntToPtr(
251+
Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
252+
ThisContextType->getPointerTo());
253+
I.eraseFromParent();
254+
break;
255+
}
256+
}
257+
if (!Context) {
258+
ORE.emit([&] {
259+
return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
260+
<< "Function doesn't have instrumentation, skipping";
261+
});
262+
return;
263+
}
264+
265+
bool ContextWasReleased = false;
266+
for (auto &BB : F) {
267+
for (auto &I : llvm::make_early_inc_range(BB)) {
268+
if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
269+
IRBuilder<> Builder(Instr);
270+
switch (Instr->getIntrinsicID()) {
271+
case llvm::Intrinsic::instrprof_increment:
272+
case llvm::Intrinsic::instrprof_increment_step: {
273+
// Increments (or increment-steps) are just a typical load - increment
274+
// - store in the RealContext.
275+
auto *AsStep = cast<InstrProfIncrementInst>(Instr);
276+
auto *GEP = Builder.CreateGEP(
277+
ThisContextType, RealContext,
278+
{Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
279+
Builder.CreateStore(
280+
Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
281+
AsStep->getStep()),
282+
GEP);
283+
} break;
284+
case llvm::Intrinsic::instrprof_callsite:
285+
// callsite lowering: write the called value in the expected callee
286+
// TLS we treat the TLS as volatile because of signal handlers and to
287+
// avoid these being moved away from the callsite they decorate.
288+
auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
289+
Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
290+
true);
291+
// write the GEP of the slot in the sub-contexts portion of the
292+
// context in TLS. Now, here, we use the actual Context value - as
293+
// returned from compiler-rt - which may have the LSB set if the
294+
// Context was scratch. Since the header of the context object and
295+
// then the values are all 8-aligned (or, really, insofar as we care,
296+
// they are even) - if the context is scratch (meaning, an odd value),
297+
// so will the GEP. This is important because this is then visible to
298+
// compiler-rt which will produce scratch contexts for callers that
299+
// have a scratch context.
300+
Builder.CreateStore(
301+
Builder.CreateGEP(ThisContextType, Context,
302+
{Builder.getInt32(0), Builder.getInt32(2),
303+
CSIntrinsic->getIndex()}),
304+
CallsiteInfoTLSAddr, true);
305+
break;
306+
}
307+
I.eraseFromParent();
308+
} else if (TheRootContext && isa<ReturnInst>(I)) {
309+
// Remember to release the context if we are an entrypoint.
310+
IRBuilder<> Builder(&I);
311+
Builder.CreateCall(ReleaseCtx, {TheRootContext});
312+
ContextWasReleased = true;
313+
}
314+
}
315+
}
316+
// FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
317+
// to disallow this, (so this then stays as an error), another is to detect
318+
// that and then do a wrapper or disallow the tail call. This only affects
319+
// instrumentation, when we want to detect the call graph.
320+
if (TheRootContext && !ContextWasReleased)
321+
F.getContext().emitError(
322+
"[ctx_prof] An entrypoint was instrumented but it has no `ret` "
323+
"instructions above which to release the context: " +
324+
F.getName());
325+
}

0 commit comments

Comments
 (0)