Skip to content

Commit 9050b27

Browse files
authored
[OpenMPIRBuilder] Remove wrapper function in createTask, createTeams (#67723)
This patch removes the wrapper function in `OpenMPIRBuilder::createTask` and `OpenMPIRBuilder::createTeams`. The outlined function is now directly of the form that is expected by the runtime library calls. This patch also adds a utility function to help add fake values and their uses, which will be deleted in finalization callbacks.

**Why did we need wrappers earlier?**

Before the post-outline callbacks are executed, the IR has the following structure:

```
define @func() {
  ;...
  call void @outlined_fn(ptr %data)
  ;...
}
define void @outlined_fn(ptr %data)
```

OpenMP offloading expects a specific signature for the outlined function in a runtime call. For example, `__kmpc_fork_teams` expects the following signature:

```
define @outlined_fn(ptr %global.tid, ptr %data)
```

As there is no way to change a function's arguments after it has been created, a wrapper function with the expected signature was created that calls the outlined function inside it.

**How are we handling it now?**

To handle this in the current patch, we create a "fake" global tid and add a "fake" use for it in the to-be-outlined region. We need to create these fake values so that the outliner sees them as something it needs to pass to the outlined function. We also tell the outliner to exclude this global tid value from the aggregate `data` argument, so that it comes as a separate argument at the beginning. This way, we are able to directly get the outlined function in the expected format. This is inspired by the way `createParallel` handles outlining (using fake values and then deleting them later). Tasks are handled with a similar approach. This simplifies the generated code, and the code that does this also becomes simpler (because we no longer have to construct a new function).
1 parent 171a3a6 commit 9050b27

File tree

4 files changed

+167
-205
lines changed

4 files changed

+167
-205
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 104 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,42 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
340340
return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
341341
}
342342

343+
// Creates a placeholder ("fake") 32-bit integer value together with a
344+
// placeholder use of that value, and returns the fake value. This is useful
345+
// in modeling the extra arguments to the outlined functions.
//
// An i32 alloca named `Name + ".addr"` is emitted at OuterAllocaIP. When
// AsPtr is true the alloca itself is the fake value; otherwise the fake value
// is an i32 load from that alloca named `Name + ".val"`. A single use of the
// fake value is then emitted at InnerAllocaIP.
//
// Every instruction created here is pushed onto ToBeDeleted, definitions
// first and the use last, so popping the stack visits the use before the
// values it depends on.
//
// Builder's insertion point is moved by this function and is not restored;
// on return it is positioned after the fake use emitted at InnerAllocaIP.
346+
Value *createFakeIntVal(IRBuilder<> &Builder,
347+
OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
348+
std::stack<Instruction *> &ToBeDeleted,
349+
OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
350+
const Twine &Name = "", bool AsPtr = true) {
351+
Builder.restoreIP(OuterAllocaIP);
352+
Instruction *FakeVal;
353+
// Backing storage for the fake value; created in both the pointer and the
// by-value case.
AllocaInst *FakeValAddr =
354+
Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
355+
ToBeDeleted.push(FakeValAddr);
356+
357+
if (AsPtr) {
358+
// Pointer form: hand out the address of the i32 slot.
FakeVal = FakeValAddr;
359+
} else {
360+
// Value form: hand out an i32 loaded from the slot. The load is a separate
// instruction, so it is tracked for deletion as well.
FakeVal =
361+
Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
362+
ToBeDeleted.push(FakeVal);
363+
}
364+
365+
// Generate a fake use of this value at the inner insertion point.
366+
Builder.restoreIP(InnerAllocaIP);
367+
Instruction *UseFakeVal;
368+
if (AsPtr) {
369+
// Pointer form: the use is an i32 load through the fake pointer.
UseFakeVal =
370+
Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
371+
} else {
372+
// Value form: the use is an add of the fake value and a constant. The
// result of CreateAdd is cast to an instruction type so it can be stored
// in ToBeDeleted.
// NOTE(review): the constant 10 looks arbitrary; the add appears to exist
// only to consume FakeVal -- confirm.
UseFakeVal =
373+
cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
374+
}
375+
ToBeDeleted.push(UseFakeVal);
376+
return FakeVal;
377+
}
378+
343379
//===----------------------------------------------------------------------===//
344380
// OpenMPIRBuilderConfig
345381
//===----------------------------------------------------------------------===//
@@ -1496,6 +1532,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
14961532
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
14971533
bool Tied, Value *Final, Value *IfCondition,
14981534
SmallVector<DependData> Dependencies) {
1535+
14991536
if (!updateToLocation(Loc))
15001537
return InsertPointTy();
15011538

@@ -1523,41 +1560,31 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
15231560
BasicBlock *TaskAllocaBB =
15241561
splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
15251562

1563+
InsertPointTy TaskAllocaIP =
1564+
InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1565+
InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1566+
BodyGenCB(TaskAllocaIP, TaskBodyIP);
1567+
15261568
OutlineInfo OI;
15271569
OI.EntryBB = TaskAllocaBB;
15281570
OI.OuterAllocaBB = AllocaIP.getBlock();
15291571
OI.ExitBB = TaskExitBB;
1530-
OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
1531-
Dependencies](Function &OutlinedFn) {
1532-
// The input IR here looks like the following-
1533-
// ```
1534-
// func @current_fn() {
1535-
// outlined_fn(%args)
1536-
// }
1537-
// func @outlined_fn(%args) { ... }
1538-
// ```
1539-
//
1540-
// This is changed to the following-
1541-
//
1542-
// ```
1543-
// func @current_fn() {
1544-
// runtime_call(..., wrapper_fn, ...)
1545-
// }
1546-
// func @wrapper_fn(..., %args) {
1547-
// outlined_fn(%args)
1548-
// }
1549-
// func @outlined_fn(%args) { ... }
1550-
// ```
15511572

1552-
// The stale call instruction will be replaced with a new call instruction
1553-
// for runtime call with a wrapper function.
1573+
// Add the thread ID argument.
1574+
std::stack<Instruction *> ToBeDeleted;
1575+
OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
1576+
Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
1577+
1578+
OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1579+
TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
1580+
// Replace the Stale CI by appropriate RTL function call.
15541581
assert(OutlinedFn.getNumUses() == 1 &&
15551582
"there must be a single user for the outlined function");
15561583
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
15571584

15581585
// HasShareds is true if any variables are captured in the outlined region,
15591586
// false otherwise.
1560-
bool HasShareds = StaleCI->arg_size() > 0;
1587+
bool HasShareds = StaleCI->arg_size() > 1;
15611588
Builder.SetInsertPoint(StaleCI);
15621589

15631590
// Gather the arguments for emitting the runtime call for
@@ -1595,7 +1622,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
15951622
Value *SharedsSize = Builder.getInt64(0);
15961623
if (HasShareds) {
15971624
AllocaInst *ArgStructAlloca =
1598-
dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
1625+
dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
15991626
assert(ArgStructAlloca &&
16001627
"Unable to find the alloca instruction corresponding to arguments "
16011628
"for extracted function");
@@ -1606,31 +1633,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
16061633
SharedsSize =
16071634
Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
16081635
}
1609-
1610-
// Argument - task_entry (the wrapper function)
1611-
// If the outlined function has some captured variables (i.e. HasShareds is
1612-
// true), then the wrapper function will have an additional argument (the
1613-
// struct containing captured variables). Otherwise, no such argument will
1614-
// be present.
1615-
SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
1616-
if (HasShareds)
1617-
WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
1618-
FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
1619-
(Twine(OutlinedFn.getName()) + ".wrapper").str(),
1620-
FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
1621-
Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
1622-
16231636
// Emit the @__kmpc_omp_task_alloc runtime call
16241637
// The runtime call returns a pointer to an area where the task captured
16251638
// variables must be copied before the task is run (TaskData)
16261639
CallInst *TaskData = Builder.CreateCall(
16271640
TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
16281641
/*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
1629-
/*task_func=*/WrapperFunc});
1642+
/*task_func=*/&OutlinedFn});
16301643

16311644
// Copy the arguments for outlined function
16321645
if (HasShareds) {
1633-
Value *Shareds = StaleCI->getArgOperand(0);
1646+
Value *Shareds = StaleCI->getArgOperand(1);
16341647
Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
16351648
Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
16361649
Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
@@ -1689,18 +1702,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
16891702
// br label %exit
16901703
// else:
16911704
// call @__kmpc_omp_task_begin_if0(...)
1692-
// call @wrapper_fn(...)
1705+
// call @outlined_fn(...)
16931706
// call @__kmpc_omp_task_complete_if0(...)
16941707
// br label %exit
16951708
// exit:
16961709
// ...
16971710
if (IfCondition) {
16981711
// `SplitBlockAndInsertIfThenElse` requires the block to have a
16991712
// terminator.
1700-
BasicBlock *NewBasicBlock =
1701-
splitBB(Builder, /*CreateBranch=*/true, "if.end");
1713+
splitBB(Builder, /*CreateBranch=*/true, "if.end");
17021714
Instruction *IfTerminator =
1703-
NewBasicBlock->getSinglePredecessor()->getTerminator();
1715+
Builder.GetInsertPoint()->getParent()->getTerminator();
17041716
Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
17051717
Builder.SetInsertPoint(IfTerminator);
17061718
SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
@@ -1711,10 +1723,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
17111723
Function *TaskCompleteFn =
17121724
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
17131725
Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
1726+
CallInst *CI = nullptr;
17141727
if (HasShareds)
1715-
Builder.CreateCall(WrapperFunc, {ThreadID, TaskData});
1728+
CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
17161729
else
1717-
Builder.CreateCall(WrapperFunc, {ThreadID});
1730+
CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
1731+
CI->setDebugLoc(StaleCI->getDebugLoc());
17181732
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
17191733
Builder.SetInsertPoint(ThenTI);
17201734
}
@@ -1736,26 +1750,20 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
17361750

17371751
StaleCI->eraseFromParent();
17381752

1739-
// Emit the body for wrapper function
1740-
BasicBlock *WrapperEntryBB =
1741-
BasicBlock::Create(M.getContext(), "", WrapperFunc);
1742-
Builder.SetInsertPoint(WrapperEntryBB);
1753+
Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
17431754
if (HasShareds) {
1744-
llvm::Value *Shareds =
1745-
Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1));
1746-
Builder.CreateCall(&OutlinedFn, {Shareds});
1747-
} else {
1748-
Builder.CreateCall(&OutlinedFn);
1755+
LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
1756+
OutlinedFn.getArg(1)->replaceUsesWithIf(
1757+
Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
1758+
}
1759+
1760+
while (!ToBeDeleted.empty()) {
1761+
ToBeDeleted.top()->eraseFromParent();
1762+
ToBeDeleted.pop();
17491763
}
1750-
Builder.CreateRet(Builder.getInt32(0));
17511764
};
17521765

17531766
addOutlineInfo(std::move(OI));
1754-
1755-
InsertPointTy TaskAllocaIP =
1756-
InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1757-
InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1758-
BodyGenCB(TaskAllocaIP, TaskBodyIP);
17591767
Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
17601768

17611769
return Builder.saveIP();
@@ -5763,84 +5771,63 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
57635771
BasicBlock *AllocaBB =
57645772
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
57655773

5774+
// Generate the body of teams.
5775+
InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
5776+
InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
5777+
BodyGenCB(AllocaIP, CodeGenIP);
5778+
57665779
OutlineInfo OI;
57675780
OI.EntryBB = AllocaBB;
57685781
OI.ExitBB = ExitBB;
57695782
OI.OuterAllocaBB = &OuterAllocaBB;
5770-
OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
5771-
// The input IR here looks like the following-
5772-
// ```
5773-
// func @current_fn() {
5774-
// outlined_fn(%args)
5775-
// }
5776-
// func @outlined_fn(%args) { ... }
5777-
// ```
5778-
//
5779-
// This is changed to the following-
5780-
//
5781-
// ```
5782-
// func @current_fn() {
5783-
// runtime_call(..., wrapper_fn, ...)
5784-
// }
5785-
// func @wrapper_fn(..., %args) {
5786-
// outlined_fn(%args)
5787-
// }
5788-
// func @outlined_fn(%args) { ... }
5789-
// ```
57905783

5784+
// Insert fake values for global tid and bound tid.
5785+
std::stack<Instruction *> ToBeDeleted;
5786+
InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
5787+
OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
5788+
Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
5789+
OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
5790+
Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
5791+
5792+
OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable {
57915793
// The stale call instruction will be replaced with a new call instruction
5792-
// for runtime call with a wrapper function.
5794+
// for runtime call with the outlined function.
57935795

57945796
assert(OutlinedFn.getNumUses() == 1 &&
57955797
"there must be a single user for the outlined function");
57965798
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
5799+
ToBeDeleted.push(StaleCI);
5800+
5801+
assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
5802+
"Outlined function must have two or three arguments only");
5803+
5804+
bool HasShared = OutlinedFn.arg_size() == 3;
57975805

5798-
// Create the wrapper function.
5799-
SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
5800-
for (auto &Arg : OutlinedFn.args())
5801-
WrapperArgTys.push_back(Arg.getType());
5802-
FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
5803-
(Twine(OutlinedFn.getName()) + ".teams").str(),
5804-
FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
5805-
Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
5806-
WrapperFunc->getArg(0)->setName("global_tid");
5807-
WrapperFunc->getArg(1)->setName("bound_tid");
5808-
if (WrapperFunc->arg_size() > 2)
5809-
WrapperFunc->getArg(2)->setName("data");
5810-
5811-
// Emit the body of the wrapper function - just a call to outlined function
5812-
// and return statement.
5813-
BasicBlock *WrapperEntryBB =
5814-
BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
5815-
Builder.SetInsertPoint(WrapperEntryBB);
5816-
SmallVector<Value *> Args;
5817-
for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
5818-
Args.push_back(WrapperFunc->getArg(ArgIndex));
5819-
Builder.CreateCall(&OutlinedFn, Args);
5820-
Builder.CreateRetVoid();
5821-
5822-
OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
5806+
OutlinedFn.getArg(0)->setName("global.tid.ptr");
5807+
OutlinedFn.getArg(1)->setName("bound.tid.ptr");
5808+
if (HasShared)
5809+
OutlinedFn.getArg(2)->setName("data");
58235810

58245811
// Call to the runtime function for teams in the current function.
58255812
assert(StaleCI && "Error while outlining - no CallInst user found for the "
58265813
"outlined function.");
58275814
Builder.SetInsertPoint(StaleCI);
5828-
Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
5829-
for (Use &Arg : StaleCI->args())
5830-
Args.push_back(Arg);
5815+
SmallVector<Value *> Args = {
5816+
Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
5817+
if (HasShared)
5818+
Args.push_back(StaleCI->getArgOperand(2));
58315819
Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
58325820
omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
58335821
Args);
5834-
StaleCI->eraseFromParent();
5822+
5823+
while (!ToBeDeleted.empty()) {
5824+
ToBeDeleted.top()->eraseFromParent();
5825+
ToBeDeleted.pop();
5826+
}
58355827
};
58365828

58375829
addOutlineInfo(std::move(OI));
58385830

5839-
// Generate the body of teams.
5840-
InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
5841-
InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
5842-
BodyGenCB(AllocaIP, CodeGenIP);
5843-
58445831
Builder.SetInsertPoint(ExitBB, ExitBB->begin());
58455832

58465833
return Builder.saveIP();

0 commit comments

Comments
 (0)