Skip to content

Commit 2cce0f6

Browse files
[OpenMP][OMPIRBuilder] Add support to omp target parallel (llvm#67000)
Added support for the LLVM IR code generation used to handle omp target parallel code. A call to __kmpc_parallel_51 is generated and the parallel region is outlined into a separate function. The proper setup of the kmpc_target_init mode is not included in this commit; it is assumed that the SPMD mode for target initialization is properly set by other codegen functions.
1 parent 1a0e743 commit 2cce0f6

File tree

3 files changed

+361
-88
lines changed

3 files changed

+361
-88
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 226 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,13 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
651651

652652
Function *OuterFn = OI.getFunction();
653653
CodeExtractorAnalysisCache CEAC(*OuterFn);
654+
// If we generate code for the target device, we need to allocate
655+
// struct for aggregate params in the device default alloca address space.
656+
// OpenMP runtime requires that the params of the extracted functions are
657+
// passed as zero address space pointers. This flag ensures that
658+
// CodeExtractor generates correct code for extracted functions
659+
// which are used by OpenMP runtime.
660+
bool ArgsInZeroAddressSpace = Config.isTargetDevice();
654661
CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
655662
/* AggregateArgs */ true,
656663
/* BlockFrequencyInfo */ nullptr,
@@ -659,7 +666,7 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
659666
/* AllowVarArgs */ true,
660667
/* AllowAlloca */ true,
661668
/* AllocaBlock*/ OI.OuterAllocaBB,
662-
/* Suffix */ ".omp_par");
669+
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
663670

664671
LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
665672
LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
@@ -1101,6 +1108,182 @@ void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
11011108
Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
11021109
}
11031110

1111+
// Callback used to create OpenMP runtime calls to support
1112+
// omp parallel clause for the device.
1113+
// We need to use this callback to replace call to the OutlinedFn in OuterFn
1114+
// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
//
// Parameters, as used by the body below:
//   OMPIRBuilder  - builder whose IRBuilder, cached types (VoidPtr, Int32,
//                   ParallelTaskPtr) and runtime-function lookup are used.
//   OutlinedFn    - the outlined parallel region; its first two arguments are
//                   the tid and bounded tid, the remaining ones are counted
//                   as captured variables.
//   OuterFn       - not referenced anywhere in this function.
//   OuterAllocaBB - block that receives the alloca for the argument array.
//   Ident         - passed as the first argument of the runtime call.
//   IfCondition   - optional; when null the constant 1 is passed instead.
//   NumThreads    - optional; when null the constant -1 is passed instead.
//   PrivTID       - instruction before which the private TID store is emitted.
//   PrivTIDAddr   - destination of that store.
//   ThreadID      - passed as the global thread number.
//   ToBeDeleted   - instructions erased at the very end of the callback.
1115+
static void targetParallelCallback(
1116+
OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1117+
BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1118+
Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1119+
Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1120+
// Add some known attributes.
// Both leading parameters (tid and bounded tid) get NoAlias and NoUndef; the
// function itself is marked NoUnwind.
1121+
IRBuilder<> &Builder = OMPIRBuilder->Builder;
1122+
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1123+
OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1124+
OutlinedFn.addParamAttr(0, Attribute::NoUndef);
1125+
OutlinedFn.addParamAttr(1, Attribute::NoUndef);
1126+
OutlinedFn.addFnAttr(Attribute::NoUnwind);
1127+
1128+
assert(OutlinedFn.arg_size() >= 2 &&
1129+
"Expected at least tid and bounded tid as arguments");
// Everything after the two thread-id parameters counts as a captured variable.
1130+
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1131+
1132+
// The call to rewrite is taken from the last entry of OutlinedFn's user list.
// NOTE(review): cast<> presumably already rejects a null or non-call user,
// which would make the assert below redundant - confirm.
CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1133+
assert(CI && "Expected call instruction to outlined function");
1134+
CI->getParent()->setName("omp_parallel");
1135+
1136+
// New instructions are emitted right before the original call.
Builder.SetInsertPoint(CI);
1137+
Type *PtrTy = OMPIRBuilder->VoidPtr;
1138+
Value *NullPtrValue = Constant::getNullValue(PtrTy);
1139+
1140+
// Add alloca for kernel args
// The insertion point is moved to the start of OuterAllocaBB for the alloca
// (and the optional cast) and restored to the call site afterwards.
1141+
OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
1142+
Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
1143+
AllocaInst *ArgsAlloca =
1144+
Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
1145+
Value *Args = ArgsAlloca;
1146+
// Add address space cast if array for storing arguments is not allocated
1147+
// in address space 0
1148+
if (ArgsAlloca->getAddressSpace())
1149+
Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
1150+
Builder.restoreIP(CurrentIP);
1151+
1152+
// Store captured vars which are used by kmpc_parallel_51
// Call operands from index 2 onwards are written into consecutive array slots.
1153+
for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
1154+
Value *V = *(CI->arg_begin() + 2 + Idx);
1155+
Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
1156+
ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
1157+
Builder.CreateStore(V, StoreAddress);
1158+
}
1159+
1160+
// Normalize the if-clause value to Int32; without an if clause pass 1.
Value *Cond =
1161+
IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
1162+
: Builder.getInt32(1);
1163+
1164+
// Build kmpc_parallel_51 call
1165+
Value *Parallel51CallArgs[] = {
1166+
/* identifier*/ Ident,
1167+
/* global thread num*/ ThreadID,
1168+
/* if expression */ Cond,
1169+
/* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
1170+
/* Proc bind */ Builder.getInt32(-1),
1171+
/* outlined function */
1172+
Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
1173+
/* wrapper function */ NullPtrValue,
1174+
/* arguments of the outlined function*/ Args,
1175+
/* number of arguments */ Builder.getInt64(NumCapturedVars)};
1176+
1177+
FunctionCallee RTLFn =
1178+
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);
1179+
1180+
Builder.CreateCall(RTLFn, Parallel51CallArgs);
1181+
1182+
LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
1183+
<< *Builder.GetInsertBlock()->getParent() << "\n");
1184+
1185+
// Initialize the local TID stack location with the argument value.
// The value is loaded as Int32 through the outlined function's first argument.
1186+
Builder.SetInsertPoint(PrivTID);
1187+
Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1188+
Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1189+
PrivTIDAddr);
1190+
1191+
// Remove redundant call to the outlined function.
1192+
CI->eraseFromParent();
1193+
1194+
// Erase the instructions the caller queued for removal.
for (Instruction *I : ToBeDeleted) {
1195+
I->eraseFromParent();
1196+
}
1198+
1199+
// Callback used to create OpenMP runtime calls to support
1199+
// omp parallel clause for the host.
1200+
// We need to use this callback to replace call to the OutlinedFn in OuterFn
1201+
// by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
//
// Parameters, as used by the body below:
//   OMPIRBuilder - builder whose IRBuilder, cached types (VoidPtr, Int32,
//                  ParallelTaskPtr) and runtime-function lookup are used.
//   OutlinedFn   - the outlined parallel region; its first two arguments are
//                  the tid and bounded tid, the remaining ones are counted
//                  as captured variables.
//   OuterFn      - not referenced anywhere in this function.
//   Ident        - passed as the first argument of the runtime call.
//   IfCondition  - optional; selects __kmpc_fork_call_if when non-null and
//                  __kmpc_fork_call otherwise.
//   PrivTID      - instruction before which the private TID store is emitted.
//   PrivTIDAddr  - destination of that store.
//   ToBeDeleted  - instructions erased at the very end of the callback.
1202+
static void
1203+
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1204+
Function *OuterFn, Value *Ident, Value *IfCondition,
1205+
Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1206+
const SmallVector<Instruction *, 4> &ToBeDeleted) {
1207+
IRBuilder<> &Builder = OMPIRBuilder->Builder;
1208+
// Pick the runtime entry point depending on whether an if clause is present.
FunctionCallee RTLFn;
1209+
if (IfCondition) {
1210+
RTLFn =
1211+
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
1212+
} else {
1213+
RTLFn =
1214+
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
1215+
}
1216+
// Attach callback metadata only when the callee is a Function that does not
// carry it yet.
if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
1217+
if (!F->hasMetadata(LLVMContext::MD_callback)) {
1218+
LLVMContext &Ctx = F->getContext();
1219+
MDBuilder MDB(Ctx);
1220+
// Annotate the callback behavior of the __kmpc_fork_call:
1221+
// - The callback callee is argument number 2 (microtask).
1222+
// - The first two arguments of the callback callee are unknown (-1).
1223+
// - All variadic arguments to the __kmpc_fork_call are passed to the
1224+
// callback callee.
1225+
F->addMetadata(LLVMContext::MD_callback,
1226+
*MDNode::get(Ctx, {MDB.createCallbackEncoding(
1227+
2, {-1, -1},
1228+
/* VarArgsArePassed */ true)}));
1229+
}
1230+
}
1231+
// Add some known attributes.
1232+
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1233+
OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1234+
OutlinedFn.addFnAttr(Attribute::NoUnwind);
1235+
1236+
assert(OutlinedFn.arg_size() >= 2 &&
1237+
"Expected at least tid and bounded tid as arguments");
// Everything after the two thread-id parameters counts as a captured variable.
1238+
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1239+
1240+
// The call to rewrite is taken from the last entry of OutlinedFn's user list;
// new instructions are emitted right before it.
CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1241+
CI->getParent()->setName("omp_parallel");
1242+
Builder.SetInsertPoint(CI);
1243+
1244+
// Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1245+
Value *ForkCallArgs[] = {
1246+
Ident, Builder.getInt32(NumCapturedVars),
1247+
Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};
1248+
1249+
// Final operand order: fixed arguments, the Int32 condition (only with an if
// clause), then the original call's operands from index 2 onwards.
SmallVector<Value *, 16> RealArgs;
1250+
RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
1251+
if (IfCondition) {
1252+
Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
1253+
RealArgs.push_back(Cond);
1254+
}
1255+
RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
1256+
1257+
// __kmpc_fork_call_if always expects a void ptr as the last argument
1258+
// If there are no arguments, pass a null pointer.
1259+
auto PtrTy = OMPIRBuilder->VoidPtr;
1260+
if (IfCondition && NumCapturedVars == 0) {
1261+
Value *NullPtrValue = Constant::getNullValue(PtrTy);
1262+
RealArgs.push_back(NullPtrValue);
1263+
}
1264+
// With an if clause, the last operand is bitcast to PtrTy when its type
// differs.
if (IfCondition && RealArgs.back()->getType() != PtrTy)
1265+
RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
1266+
1267+
Builder.CreateCall(RTLFn, RealArgs);
1268+
1269+
LLVM_DEBUG(dbgs() << "With fork_call placed: "
1270+
<< *Builder.GetInsertBlock()->getParent() << "\n");
1271+
1272+
// Initialize the local TID stack location with the argument value.
// The value is loaded as Int32 through the outlined function's first argument.
1273+
Builder.SetInsertPoint(PrivTID);
1274+
Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1275+
Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1276+
PrivTIDAddr);
1277+
1278+
// Remove redundant call to the outlined function.
1279+
CI->eraseFromParent();
1280+
1281+
// Erase the instructions the caller queued for removal.
for (Instruction *I : ToBeDeleted) {
1282+
I->eraseFromParent();
1283+
}
1284+
}
1286+
11041287
IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
11051288
const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
11061289
BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
@@ -1115,6 +1298,12 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
11151298
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
11161299
Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
11171300
Value *ThreadID = getOrCreateThreadID(Ident);
1301+
// If we generate code for the target device, we need to allocate
1302+
// struct for aggregate params in the device default alloca address space.
1303+
// OpenMP runtime requires that the params of the extracted functions are
1304+
// passed as zero address space pointers. This flag ensures that extracted
1305+
// function arguments are declared in zero address space
1306+
bool ArgsInZeroAddressSpace = Config.isTargetDevice();
11181307

11191308
if (NumThreads) {
11201309
// Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
@@ -1148,13 +1337,28 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
11481337
// Change the location to the outer alloca insertion point to create and
11491338
// initialize the allocas we pass into the parallel region.
11501339
Builder.restoreIP(OuterAllocaIP);
1151-
AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
1152-
AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr");
1340+
AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
1341+
AllocaInst *ZeroAddrAlloca =
1342+
Builder.CreateAlloca(Int32, nullptr, "zero.addr");
1343+
Instruction *TIDAddr = TIDAddrAlloca;
1344+
Instruction *ZeroAddr = ZeroAddrAlloca;
1345+
if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1346+
// Add additional casts to enforce pointers in zero address space
1347+
TIDAddr = new AddrSpaceCastInst(
1348+
TIDAddrAlloca, PointerType ::get(M.getContext(), 0), "tid.addr.ascast");
1349+
TIDAddr->insertAfter(TIDAddrAlloca);
1350+
ToBeDeleted.push_back(TIDAddr);
1351+
ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1352+
PointerType ::get(M.getContext(), 0),
1353+
"zero.addr.ascast");
1354+
ZeroAddr->insertAfter(ZeroAddrAlloca);
1355+
ToBeDeleted.push_back(ZeroAddr);
1356+
}
11531357

11541358
// We only need TIDAddr and ZeroAddr for modeling purposes to get the
11551359
// associated arguments in the outlined function, so we delete them later.
1156-
ToBeDeleted.push_back(TIDAddr);
1157-
ToBeDeleted.push_back(ZeroAddr);
1360+
ToBeDeleted.push_back(TIDAddrAlloca);
1361+
ToBeDeleted.push_back(ZeroAddrAlloca);
11581362

11591363
// Create an artificial insertion point that will also ensure the blocks we
11601364
// are about to split are not degenerated.
@@ -1222,87 +1426,24 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
12221426
BodyGenCB(InnerAllocaIP, CodeGenIP);
12231427

12241428
LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
1225-
FunctionCallee RTLFn;
1226-
if (IfCondition)
1227-
RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
1228-
else
1229-
RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
1230-
1231-
if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
1232-
if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
1233-
llvm::LLVMContext &Ctx = F->getContext();
1234-
MDBuilder MDB(Ctx);
1235-
// Annotate the callback behavior of the __kmpc_fork_call:
1236-
// - The callback callee is argument number 2 (microtask).
1237-
// - The first two arguments of the callback callee are unknown (-1).
1238-
// - All variadic arguments to the __kmpc_fork_call are passed to the
1239-
// callback callee.
1240-
F->addMetadata(
1241-
llvm::LLVMContext::MD_callback,
1242-
*llvm::MDNode::get(
1243-
Ctx, {MDB.createCallbackEncoding(2, {-1, -1},
1244-
/* VarArgsArePassed */ true)}));
1245-
}
1246-
}
12471429

12481430
OutlineInfo OI;
1249-
OI.PostOutlineCB = [=](Function &OutlinedFn) {
1250-
// Add some known attributes.
1251-
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1252-
OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1253-
OutlinedFn.addFnAttr(Attribute::NoUnwind);
1254-
OutlinedFn.addFnAttr(Attribute::NoRecurse);
1255-
1256-
assert(OutlinedFn.arg_size() >= 2 &&
1257-
"Expected at least tid and bounded tid as arguments");
1258-
unsigned NumCapturedVars =
1259-
OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1260-
1261-
CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1262-
CI->getParent()->setName("omp_parallel");
1263-
Builder.SetInsertPoint(CI);
1264-
1265-
// Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1266-
Value *ForkCallArgs[] = {
1267-
Ident, Builder.getInt32(NumCapturedVars),
1268-
Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};
1269-
1270-
SmallVector<Value *, 16> RealArgs;
1271-
RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
1272-
if (IfCondition) {
1273-
Value *Cond = Builder.CreateSExtOrTrunc(IfCondition,
1274-
Type::getInt32Ty(M.getContext()));
1275-
RealArgs.push_back(Cond);
1276-
}
1277-
RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
1278-
1279-
// __kmpc_fork_call_if always expects a void ptr as the last argument
1280-
// If there are no arguments, pass a null pointer.
1281-
auto PtrTy = Type::getInt8PtrTy(M.getContext());
1282-
if (IfCondition && NumCapturedVars == 0) {
1283-
llvm::Value *Void = ConstantPointerNull::get(PtrTy);
1284-
RealArgs.push_back(Void);
1285-
}
1286-
if (IfCondition && RealArgs.back()->getType() != PtrTy)
1287-
RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
1288-
1289-
Builder.CreateCall(RTLFn, RealArgs);
1290-
1291-
LLVM_DEBUG(dbgs() << "With fork_call placed: "
1292-
<< *Builder.GetInsertBlock()->getParent() << "\n");
1293-
1294-
InsertPointTy ExitIP(PRegExitBB, PRegExitBB->end());
1295-
1296-
// Initialize the local TID stack location with the argument value.
1297-
Builder.SetInsertPoint(PrivTID);
1298-
Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1299-
Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);
1300-
1301-
CI->eraseFromParent();
1302-
1303-
for (Instruction *I : ToBeDeleted)
1304-
I->eraseFromParent();
1305-
};
1431+
if (Config.isTargetDevice()) {
1432+
// Generate OpenMP target specific runtime call
1433+
OI.PostOutlineCB = [=, ToBeDeletedVec =
1434+
std::move(ToBeDeleted)](Function &OutlinedFn) {
1435+
targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
1436+
IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1437+
ThreadID, ToBeDeletedVec);
1438+
};
1439+
} else {
1440+
// Generate OpenMP host runtime call
1441+
OI.PostOutlineCB = [=, ToBeDeletedVec =
1442+
std::move(ToBeDeleted)](Function &OutlinedFn) {
1443+
hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
1444+
PrivTID, PrivTIDAddr, ToBeDeletedVec);
1445+
};
1446+
}
13061447

13071448
// Adjust the finalization stack, verify the adjustment, and call the
13081449
// finalize function a last time to finalize values between the pre-fini
@@ -1342,7 +1483,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
13421483
/* AllowVarArgs */ true,
13431484
/* AllowAlloca */ true,
13441485
/* AllocationBlock */ OuterAllocaBlock,
1345-
/* Suffix */ ".omp_par");
1486+
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
13461487

13471488
// Find inputs to, outputs from the code region.
13481489
BasicBlock *CommonExit = nullptr;

llvm/lib/Transforms/IPO/OpenMPOpt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ struct OMPInformationCache : public InformationCache {
286286
: InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287287
OpenMPPostLink(OpenMPPostLink) {
288288

289+
OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
289290
OMPBuilder.initialize();
290291
initializeRuntimeFunctions(M);
291292
initializeInternalControlVars();

0 commit comments

Comments
 (0)