Skip to content

[OpenMP][OMPIRBuilder] Add support to omp target parallel #67000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 226 additions & 85 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,13 @@ void OpenMPIRBuilder::finalize(Function *Fn) {

Function *OuterFn = OI.getFunction();
CodeExtractorAnalysisCache CEAC(*OuterFn);
// If we generate code for the target device, we need to allocate
// struct for aggregate params in the device default alloca address space.
// OpenMP runtime requires that the params of the extracted functions are
// passed as zero address space pointers. This flag ensures that
// CodeExtractor generates correct code for extracted functions
// which are used by OpenMP runtime.
bool ArgsInZeroAddressSpace = Config.isTargetDevice();
CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ true,
/* BlockFrequencyInfo */ nullptr,
Expand All @@ -659,7 +666,7 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* AllocaBlock*/ OI.OuterAllocaBB,
/* Suffix */ ".omp_par");
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
Expand Down Expand Up @@ -1101,6 +1108,182 @@ void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}

// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the device.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
static void targetParallelCallback(
OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
// Add some known attributes.
IRBuilder<> &Builder = OMPIRBuilder->Builder;
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
OutlinedFn.addParamAttr(1, Attribute::NoAlias);
OutlinedFn.addParamAttr(0, Attribute::NoUndef);
OutlinedFn.addParamAttr(1, Attribute::NoUndef);
OutlinedFn.addFnAttr(Attribute::NoUnwind);

assert(OutlinedFn.arg_size() >= 2 &&
"Expected at least tid and bounded tid as arguments");
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
assert(CI && "Expected call instruction to outlined function");
CI->getParent()->setName("omp_parallel");

Builder.SetInsertPoint(CI);
Type *PtrTy = OMPIRBuilder->VoidPtr;
Value *NullPtrValue = Constant::getNullValue(PtrTy);

// Add alloca for kernel args
OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
AllocaInst *ArgsAlloca =
Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
Value *Args = ArgsAlloca;
// Add address space cast if array for storing arguments is not allocated
// in address space 0
if (ArgsAlloca->getAddressSpace())
Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
Builder.restoreIP(CurrentIP);

// Store captured vars which are used by kmpc_parallel_51
for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
Value *V = *(CI->arg_begin() + 2 + Idx);
Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
Builder.CreateStore(V, StoreAddress);
}

Value *Cond =
IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
: Builder.getInt32(1);

// Build kmpc_parallel_51 call
Value *Parallel51CallArgs[] = {
/* identifier*/ Ident,
/* global thread num*/ ThreadID,
/* if expression */ Cond,
/* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
/* Proc bind */ Builder.getInt32(-1),
/* outlined function */
Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
/* wrapper function */ NullPtrValue,
/* arguments of the outlined funciton*/ Args,
/* number of arguments */ Builder.getInt64(NumCapturedVars)};

FunctionCallee RTLFn =
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);

Builder.CreateCall(RTLFn, Parallel51CallArgs);

LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
<< *Builder.GetInsertBlock()->getParent() << "\n");

// Initialize the local TID stack location with the argument value.
Builder.SetInsertPoint(PrivTID);
Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
PrivTIDAddr);

// Remove redundant call to the outlined function.
CI->eraseFromParent();

for (Instruction *I : ToBeDeleted) {
I->eraseFromParent();
}
}

// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the host.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
// by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
static void
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
Function *OuterFn, Value *Ident, Value *IfCondition,
Instruction *PrivTID, AllocaInst *PrivTIDAddr,
const SmallVector<Instruction *, 4> &ToBeDeleted) {
IRBuilder<> &Builder = OMPIRBuilder->Builder;
FunctionCallee RTLFn;
if (IfCondition) {
RTLFn =
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
} else {
RTLFn =
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
}
if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
if (!F->hasMetadata(LLVMContext::MD_callback)) {
LLVMContext &Ctx = F->getContext();
MDBuilder MDB(Ctx);
// Annotate the callback behavior of the __kmpc_fork_call:
// - The callback callee is argument number 2 (microtask).
// - The first two arguments of the callback callee are unknown (-1).
// - All variadic arguments to the __kmpc_fork_call are passed to the
// callback callee.
F->addMetadata(LLVMContext::MD_callback,
*MDNode::get(Ctx, {MDB.createCallbackEncoding(
2, {-1, -1},
/* VarArgsArePassed */ true)}));
}
}
// Add some known attributes.
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
OutlinedFn.addParamAttr(1, Attribute::NoAlias);
OutlinedFn.addFnAttr(Attribute::NoUnwind);

assert(OutlinedFn.arg_size() >= 2 &&
"Expected at least tid and bounded tid as arguments");
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
CI->getParent()->setName("omp_parallel");
Builder.SetInsertPoint(CI);

// Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
Value *ForkCallArgs[] = {
Ident, Builder.getInt32(NumCapturedVars),
Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};

SmallVector<Value *, 16> RealArgs;
RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
if (IfCondition) {
Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
RealArgs.push_back(Cond);
}
RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

// __kmpc_fork_call_if always expects a void ptr as the last argument
// If there are no arguments, pass a null pointer.
auto PtrTy = OMPIRBuilder->VoidPtr;
if (IfCondition && NumCapturedVars == 0) {
Value *NullPtrValue = Constant::getNullValue(PtrTy);
RealArgs.push_back(NullPtrValue);
}
if (IfCondition && RealArgs.back()->getType() != PtrTy)
RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not look sound, but you just moved it, it's ok.


Builder.CreateCall(RTLFn, RealArgs);

LLVM_DEBUG(dbgs() << "With fork_call placed: "
<< *Builder.GetInsertBlock()->getParent() << "\n");

// Initialize the local TID stack location with the argument value.
Builder.SetInsertPoint(PrivTID);
Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
PrivTIDAddr);

// Remove redundant call to the outlined function.
CI->eraseFromParent();

for (Instruction *I : ToBeDeleted) {
I->eraseFromParent();
}
}

IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
Expand All @@ -1115,6 +1298,12 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
Value *ThreadID = getOrCreateThreadID(Ident);
// If we generate code for the target device, we need to allocate
// struct for aggregate params in the device default alloca address space.
// OpenMP runtime requires that the params of the extracted functions are
// passed as zero address space pointers. This flag ensures that extracted
// function arguments are declared in zero address space
bool ArgsInZeroAddressSpace = Config.isTargetDevice();

if (NumThreads) {
// Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
Expand Down Expand Up @@ -1148,13 +1337,28 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
// Change the location to the outer alloca insertion point to create and
// initialize the allocas we pass into the parallel region.
Builder.restoreIP(OuterAllocaIP);
AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr");
AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
AllocaInst *ZeroAddrAlloca =
Builder.CreateAlloca(Int32, nullptr, "zero.addr");
Instruction *TIDAddr = TIDAddrAlloca;
Instruction *ZeroAddr = ZeroAddrAlloca;
if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
// Add additional casts to enforce pointers in zero address space
TIDAddr = new AddrSpaceCastInst(
TIDAddrAlloca, PointerType ::get(M.getContext(), 0), "tid.addr.ascast");
TIDAddr->insertAfter(TIDAddrAlloca);
ToBeDeleted.push_back(TIDAddr);
ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
PointerType ::get(M.getContext(), 0),
"zero.addr.ascast");
ZeroAddr->insertAfter(ZeroAddrAlloca);
ToBeDeleted.push_back(ZeroAddr);
}

// We only need TIDAddr and ZeroAddr for modeling purposes to get the
// associated arguments in the outlined function, so we delete them later.
ToBeDeleted.push_back(TIDAddr);
ToBeDeleted.push_back(ZeroAddr);
ToBeDeleted.push_back(TIDAddrAlloca);
ToBeDeleted.push_back(ZeroAddrAlloca);

// Create an artificial insertion point that will also ensure the blocks we
// are about to split are not degenerated.
Expand Down Expand Up @@ -1222,87 +1426,24 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
BodyGenCB(InnerAllocaIP, CodeGenIP);

LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
FunctionCallee RTLFn;
if (IfCondition)
RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
else
RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);

if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
llvm::LLVMContext &Ctx = F->getContext();
MDBuilder MDB(Ctx);
// Annotate the callback behavior of the __kmpc_fork_call:
// - The callback callee is argument number 2 (microtask).
// - The first two arguments of the callback callee are unknown (-1).
// - All variadic arguments to the __kmpc_fork_call are passed to the
// callback callee.
F->addMetadata(
llvm::LLVMContext::MD_callback,
*llvm::MDNode::get(
Ctx, {MDB.createCallbackEncoding(2, {-1, -1},
/* VarArgsArePassed */ true)}));
}
}

OutlineInfo OI;
OI.PostOutlineCB = [=](Function &OutlinedFn) {
// Add some known attributes.
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
OutlinedFn.addParamAttr(1, Attribute::NoAlias);
OutlinedFn.addFnAttr(Attribute::NoUnwind);
OutlinedFn.addFnAttr(Attribute::NoRecurse);

assert(OutlinedFn.arg_size() >= 2 &&
"Expected at least tid and bounded tid as arguments");
unsigned NumCapturedVars =
OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
CI->getParent()->setName("omp_parallel");
Builder.SetInsertPoint(CI);

// Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
Value *ForkCallArgs[] = {
Ident, Builder.getInt32(NumCapturedVars),
Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};

SmallVector<Value *, 16> RealArgs;
RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
if (IfCondition) {
Value *Cond = Builder.CreateSExtOrTrunc(IfCondition,
Type::getInt32Ty(M.getContext()));
RealArgs.push_back(Cond);
}
RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

// __kmpc_fork_call_if always expects a void ptr as the last argument
// If there are no arguments, pass a null pointer.
auto PtrTy = Type::getInt8PtrTy(M.getContext());
if (IfCondition && NumCapturedVars == 0) {
llvm::Value *Void = ConstantPointerNull::get(PtrTy);
RealArgs.push_back(Void);
}
if (IfCondition && RealArgs.back()->getType() != PtrTy)
RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);

Builder.CreateCall(RTLFn, RealArgs);

LLVM_DEBUG(dbgs() << "With fork_call placed: "
<< *Builder.GetInsertBlock()->getParent() << "\n");

InsertPointTy ExitIP(PRegExitBB, PRegExitBB->end());

// Initialize the local TID stack location with the argument value.
Builder.SetInsertPoint(PrivTID);
Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);

CI->eraseFromParent();

for (Instruction *I : ToBeDeleted)
I->eraseFromParent();
};
if (Config.isTargetDevice()) {
// Generate OpenMP target specific runtime call
OI.PostOutlineCB = [=, ToBeDeletedVec =
std::move(ToBeDeleted)](Function &OutlinedFn) {
targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
IfCondition, NumThreads, PrivTID, PrivTIDAddr,
ThreadID, ToBeDeletedVec);
};
} else {
// Generate OpenMP host runtime call
OI.PostOutlineCB = [=, ToBeDeletedVec =
std::move(ToBeDeleted)](Function &OutlinedFn) {
hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
PrivTID, PrivTIDAddr, ToBeDeletedVec);
};
}

// Adjust the finalization stack, verify the adjustment, and call the
// finalize function a last time to finalize values between the pre-fini
Expand Down Expand Up @@ -1342,7 +1483,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* AllocationBlock */ OuterAllocaBlock,
/* Suffix */ ".omp_par");
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

// Find inputs to, outputs from the code region.
BasicBlock *CommonExit = nullptr;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ struct OMPInformationCache : public InformationCache {
: InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
OpenMPPostLink(OpenMPPostLink) {

OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
OMPBuilder.initialize();
initializeRuntimeFunctions(M);
initializeInternalControlVars();
Expand Down
Loading