 //

 #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/Support/CommandLine.h"
+#include <utility>

 using namespace llvm;

+#define DEBUG_TYPE "ctx-profile-lower"
+
 static cl::list<std::string> ContextRoots(
     "profile-context-root", cl::Hidden,
     cl::desc(
@@ -22,3 +31,295 @@ static cl::list<std::string> ContextRoots(
 bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
   return !ContextRoots.empty();
 }
+
+// the names of symbols we expect in compiler-rt. Using a namespace for
+// readability.
+namespace CompilerRtAPINames {
+static auto StartCtx = "__llvm_ctx_profile_start_context";
+static auto ReleaseCtx = "__llvm_ctx_profile_release_context";
+static auto GetCtx = "__llvm_ctx_profile_get_context";
+static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
+static auto CallsiteTLS = "__llvm_ctx_profile_callsite";
+} // namespace CompilerRtAPINames
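+
+// How the lowering below uses these: StartCtx/GetCtx return the current
+// function's context object, ReleaseCtx is called on an entrypoint's returns,
+// and the two TLS slots pass the expected callee and the callsite slot
+// address from an instrumented caller down to its callees.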
+
+namespace {
+// The lowering logic and state.
+class CtxInstrumentationLowerer final {
+  Module &M;
+  ModuleAnalysisManager &MAM;
+  Type *ContextNodeTy = nullptr;
+  Type *ContextRootTy = nullptr;
+
+  DenseMap<const Function *, Constant *> ContextRootMap;
+  Function *StartCtx = nullptr;
+  Function *GetCtx = nullptr;
+  Function *ReleaseCtx = nullptr;
+  GlobalVariable *ExpectedCalleeTLS = nullptr;
+  GlobalVariable *CallsiteInfoTLS = nullptr;
+
+public:
+  CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
+  void lowerFunction(Function &F);
+};
+
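+// Each llvm.instrprof.increment[.step] and llvm.instrprof.callsite intrinsic
+// carries the function-wide total of counters / callsites, so finding one
+// instance of each is enough to size this function's context.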
+std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
+  uint32_t NrCounters = 0;
+  uint32_t NrCallsites = 0;
+  for (const auto &BB : F) {
+    for (const auto &I : BB) {
+      if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
+        if (!NrCounters)
+          NrCounters =
+              static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
+      } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
+        if (!NrCallsites)
+          NrCallsites =
+              static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
+      }
+      if (NrCounters && NrCallsites)
+        return std::make_pair(NrCounters, NrCallsites);
+    }
+  }
+  return {0, 0};
+}
+} // namespace
+
+// set up tie-in with compiler-rt.
+// NOTE!!!
+// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
+CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
+                                                     ModuleAnalysisManager &MAM)
+    : M(M), MAM(MAM) {
+  auto *PointerTy = PointerType::get(M.getContext(), 0);
+  auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
+  auto *I32Ty = Type::getInt32Ty(M.getContext());
+  auto *I64Ty = Type::getInt64Ty(M.getContext());
+
+  // The ContextRoot type
+  ContextRootTy =
+      StructType::get(M.getContext(), {
+                                          PointerTy,          /*FirstNode*/
+                                          PointerTy,          /*FirstMemBlock*/
+                                          PointerTy,          /*CurrentMem*/
+                                          SanitizerMutexType, /*Taken*/
+                                      });
+  // The Context header.
+  ContextNodeTy = StructType::get(M.getContext(), {
+                                                      I64Ty,     /*Guid*/
+                                                      PointerTy, /*Next*/
+                                                      I32Ty,     /*NrCounters*/
+                                                      I32Ty,     /*NrCallsites*/
+                                                  });
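+  // Note: this is only the fixed header. The full per-function context type
+  // (the header followed by the i64 counter array and the callsite pointer
+  // array) is created in lowerFunction, once that function's counts are known.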
+
+  // Define a global for each entrypoint. We'll reuse the entrypoint's name as
+  // prefix. We assume the entrypoint names to be unique.
+  for (const auto &Fname : ContextRoots) {
+    if (const auto *F = M.getFunction(Fname)) {
+      if (F->isDeclaration())
+        continue;
+      auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
+      cast<GlobalVariable>(G)->setInitializer(
+          Constant::getNullValue(ContextRootTy));
+      ContextRootMap.insert(std::make_pair(F, G));
+    }
+  }
+
+  // Declare the functions we will call.
+  StartCtx = cast<Function>(
+      M.getOrInsertFunction(
+           CompilerRtAPINames::StartCtx,
+           FunctionType::get(ContextNodeTy->getPointerTo(),
+                             {ContextRootTy->getPointerTo(), /*ContextRoot*/
+                              I64Ty, /*Guid*/ I32Ty,
+                              /*NrCounters*/ I32Ty /*NrCallsites*/},
+                             false))
+          .getCallee());
+  GetCtx = cast<Function>(
+      M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
+                            FunctionType::get(ContextNodeTy->getPointerTo(),
+                                              {PointerTy, /*Callee*/
+                                               I64Ty,     /*Guid*/
+                                               I32Ty,     /*NrCounters*/
+                                               I32Ty},    /*NrCallsites*/
+                                              false))
+          .getCallee());
+  ReleaseCtx = cast<Function>(
+      M.getOrInsertFunction(
+           CompilerRtAPINames::ReleaseCtx,
+           FunctionType::get(Type::getVoidTy(M.getContext()),
+                             {
+                                 ContextRootTy->getPointerTo(), /*ContextRoot*/
+                             },
+                             false))
+          .getCallee());
+
+  // Declare the TLSes we will need to use.
+  CallsiteInfoTLS =
+      new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
+                         nullptr, CompilerRtAPINames::CallsiteTLS);
+  CallsiteInfoTLS->setThreadLocal(true);
+  CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
+  ExpectedCalleeTLS =
+      new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
+                         nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
+  ExpectedCalleeTLS->setThreadLocal(true);
+  ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
+}
+
+PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
+                                              ModuleAnalysisManager &MAM) {
+  CtxInstrumentationLowerer Lowerer(M, MAM);
+  for (auto &F : M)
+    Lowerer.lowerFunction(F);
+  return PreservedAnalyses::none();
+}
+
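+// Lowers one function: in the entry block, the llvm.instrprof.increment
+// marker is replaced with a StartCtx (entrypoint) or GetCtx (regular
+// function) call and, if there are callsites, with the addresses of the two
+// TLS slots; the remaining instrprof intrinsics then become plain loads and
+// stores into the obtained context, and entrypoints release their root
+// context before each `ret`.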
+void CtxInstrumentationLowerer::lowerFunction(Function &F) {
+  if (F.isDeclaration())
+    return;
+  auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+
+  Value *Guid = nullptr;
+  auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);
+
+  Value *Context = nullptr;
+  Value *RealContext = nullptr;
+
+  StructType *ThisContextType = nullptr;
+  Value *TheRootContext = nullptr;
+  Value *ExpectedCalleeTLSAddr = nullptr;
+  Value *CallsiteInfoTLSAddr = nullptr;
+
+  auto &Head = F.getEntryBlock();
+  for (auto &I : Head) {
+    // Find the increment intrinsic in the entry basic block.
+    if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
+      assert(Mark->getIndex()->isZero());
+
+      IRBuilder<> Builder(Mark);
+      // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
+      Guid = Builder.getInt64(F.getGUID());
+      // The type of the context of this function is now knowable since we have
+      // NrCallsites and NrCounters. We declare it here because it's more
+      // convenient - we have the Builder.
+      ThisContextType = StructType::get(
+          F.getContext(),
+          {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
+           ArrayType::get(Builder.getPtrTy(), NrCallsites)});
+      // Figure out which way we obtain the context object for this function -
+      // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
+      // former case, we also set TheRootContext since we need it to release it
+      // at the end (plus it can be used to know if we have an entrypoint or a
+      // regular function).
+      auto Iter = ContextRootMap.find(&F);
+      if (Iter != ContextRootMap.end()) {
+        TheRootContext = Iter->second;
+        Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid,
+                                                Builder.getInt32(NrCounters),
+                                                Builder.getInt32(NrCallsites)});
+        ORE.emit(
+            [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
+      } else {
+        Context =
+            Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
+                                        Builder.getInt32(NrCallsites)});
+        ORE.emit([&] {
+          return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
+        });
+      }
+      // The context could be scratch.
+      auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
+      if (NrCallsites > 0) {
+        // Figure out which index of the TLS 2-element buffers to use.
+        // Scratch context => we use index == 1. Real contexts => index == 0.
+        auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
+        // The GEPs corresponding to that index, in the respective TLS.
+        ExpectedCalleeTLSAddr = Builder.CreateGEP(
+            Builder.getInt8Ty()->getPointerTo(),
+            Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
+        CallsiteInfoTLSAddr = Builder.CreateGEP(
+            Builder.getInt32Ty(),
+            Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
+      }
+      // Because the context pointer may have LSB set (to indicate scratch),
+      // clear it for the value we use as base address for the counter vector.
+      // This way, if later we want to have "real" (not clobbered) buffers
+      // acting as scratch, the lowering (at least this part of it that deals
+      // with counters) stays the same.
+      RealContext = Builder.CreateIntToPtr(
+          Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
+          ThisContextType->getPointerTo());
+      I.eraseFromParent();
+      break;
+    }
+  }
+  if (!Context) {
+    ORE.emit([&] {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
+             << "Function doesn't have instrumentation, skipping";
+    });
+    return;
+  }
+
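+  // Rewrite every remaining instrprof intrinsic into plain loads/stores into
+  // the context and, for entrypoints, release the root context before each
+  // return.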
+  bool ContextWasReleased = false;
+  for (auto &BB : F) {
+    for (auto &I : llvm::make_early_inc_range(BB)) {
+      if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
+        IRBuilder<> Builder(Instr);
+        switch (Instr->getIntrinsicID()) {
+        case llvm::Intrinsic::instrprof_increment:
+        case llvm::Intrinsic::instrprof_increment_step: {
+          // Increments (or increment-steps) are just a typical load - increment
+          // - store in the RealContext.
+          auto *AsStep = cast<InstrProfIncrementInst>(Instr);
+          auto *GEP = Builder.CreateGEP(
+              ThisContextType, RealContext,
+              {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
+          Builder.CreateStore(
+              Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
+                                AsStep->getStep()),
+              GEP);
+        } break;
+        case llvm::Intrinsic::instrprof_callsite:
+          // Callsite lowering: write the called value in the expected callee
+          // TLS. We treat the TLS as volatile because of signal handlers and
+          // to avoid these stores being moved away from the callsite they
+          // decorate.
+          auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
+          Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
+                              true);
+          // Write the GEP of the slot in the sub-contexts portion of the
+          // context in TLS. Note that here we use the actual Context value -
+          // as returned from compiler-rt - which may have the LSB set if the
+          // Context was scratch. Since the header of the context object and
+          // then the values are all 8-aligned (or, really, insofar as we care,
+          // they are even), if the context is scratch (meaning, an odd value),
+          // so will the GEP be. This is important because this is then visible
+          // to compiler-rt, which will produce scratch contexts for callers
+          // that have a scratch context.
+          Builder.CreateStore(
+              Builder.CreateGEP(ThisContextType, Context,
+                                {Builder.getInt32(0), Builder.getInt32(2),
+                                 CSIntrinsic->getIndex()}),
+              CallsiteInfoTLSAddr, true);
+          break;
+        }
+        I.eraseFromParent();
+      } else if (TheRootContext && isa<ReturnInst>(I)) {
+        // Remember to release the context if we are an entrypoint.
+        IRBuilder<> Builder(&I);
+        Builder.CreateCall(ReleaseCtx, {TheRootContext});
+        ContextWasReleased = true;
+      }
+    }
+  }
+  // FIXME: This would happen if the entrypoint tailcalls. One way to fix it
+  // would be to disallow that (so this stays an error); another is to detect
+  // the situation and insert a wrapper, or otherwise disallow the tail call.
+  // This only affects instrumentation, when we want to detect the call graph.
+  if (TheRootContext && !ContextWasReleased)
+    F.getContext().emitError(
+        "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
+        "instructions above which to release the context: " +
+        F.getName());
+}