diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 75da461cfd8d9..b505daae7e75f 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2166,6 +2166,10 @@ class OpenMPIRBuilder { using TargetBodyGenCallbackTy = function_ref; + using TargetGenArgAccessorsCallbackTy = function_ref; + /// Generator for '#omp target' /// /// \param Loc where the target data construct was encountered. @@ -2177,6 +2181,8 @@ class OpenMPIRBuilder { /// \param Inputs The input values to the region that will be passed. /// as arguments to the outlined function. /// \param BodyGenCB Callback that will generate the region code. + /// \param ArgAccessorFuncCB Callback that will generate accessor + /// instructions for passed-in target arguments where necessary. InsertPointTy createTarget(const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP, @@ -2184,7 +2190,8 @@ class OpenMPIRBuilder { int32_t NumThreads, SmallVectorImpl &Inputs, GenMapInfoCallbackTy GenMapInfoCB, - TargetBodyGenCallbackTy BodyGenCB); + TargetBodyGenCallbackTy BodyGenCB, + TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB); /// Returns __kmpc_for_static_init_* runtime function for the specified /// size \a IVSize and sign \a IVSigned. Will create a distribute call diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 72e1af55fe63f..c95dbbe996660 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -4539,25 +4539,11 @@ FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize, return getOrCreateRuntimeFunction(M, Name); } -// Copy input from pointer or i64 to the expected argument type. 
-static Value *copyInput(IRBuilderBase &Builder, unsigned AddrSpace, - Value *Input, Argument &Arg) { - auto Addr = Builder.CreateAlloca(Arg.getType()->isPointerTy() - ? Arg.getType() - : Type::getInt64Ty(Builder.getContext()), - AddrSpace); - auto AddrAscast = - Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType()); - Builder.CreateStore(&Arg, AddrAscast); - auto Copy = Builder.CreateLoad(Arg.getType(), AddrAscast); - - return Copy; -} - -static Function * -createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, - StringRef FuncName, SmallVectorImpl &Inputs, - OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { +static Function *createOutlinedFunction( + OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName, + SmallVectorImpl &Inputs, + OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc, + OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) { SmallVector ParameterTypes; if (OMPBuilder.Config.isTargetDevice()) { // All parameters to target devices are passed as pointers @@ -4597,18 +4583,20 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, // Insert return instruction. Builder.CreateRetVoid(); - // Rewrite uses of input valus to parameters. + // New Alloca IP at entry point of created device function. + Builder.SetInsertPoint(EntryBB->getFirstNonPHI()); + auto AllocaIP = Builder.saveIP(); + Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg()); + + // Rewrite uses of input values to parameters. for (auto InArg : zip(Inputs, Func->args())) { Value *Input = std::get<0>(InArg); Argument &Arg = std::get<1>(InArg); + Value *InputCopy = nullptr; - Value *InputCopy = - OMPBuilder.Config.isTargetDevice() - ? 
copyInput(Builder, - OMPBuilder.M.getDataLayout().getAllocaAddrSpace(), - Input, Arg) - : &Arg; + Builder.restoreIP( + ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP())); // Collect all the instructions for (User *User : make_early_inc_range(Input->users())) @@ -4623,18 +4611,19 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, return Func; } -static void -emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, - TargetRegionEntryInfo &EntryInfo, - Function *&OutlinedFn, Constant *&OutlinedFnID, - int32_t NumTeams, int32_t NumThreads, - SmallVectorImpl &Inputs, - OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { +static void emitTargetOutlinedFunction( + OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn, + Constant *&OutlinedFnID, int32_t NumTeams, int32_t NumThreads, + SmallVectorImpl &Inputs, + OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc, + OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) { OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction = - [&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) { + [&OMPBuilder, &Builder, &Inputs, &CBFunc, + &ArgAccessorFuncCB](StringRef EntryFnName) { return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs, - CBFunc); + CBFunc, ArgAccessorFuncCB); }; OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, @@ -4698,7 +4687,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams, int32_t NumThreads, SmallVectorImpl &Args, - GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy CBFunc) { + GenMapInfoCallbackTy GenMapInfoCB, + OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, + OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB) { if (!updateToLocation(Loc)) 
return InsertPointTy(); @@ -4707,7 +4698,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( Function *OutlinedFn; Constant *OutlinedFnID; emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, - OutlinedFnID, NumTeams, NumThreads, Args, CBFunc); + OutlinedFnID, NumTeams, NumThreads, Args, CBFunc, + ArgAccessorFuncCB); if (!Config.isTargetDevice()) emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams, NumThreads, Args, GenMapInfoCB); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 4704d8cb492b7..5de9a7073604a 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5236,6 +5236,33 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { Inputs.push_back(BPtr); Inputs.push_back(CPtr); + auto SimpleArgAccessorCB = + [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal, + llvm::OpenMPIRBuilder::InsertPointTy AllocaIP, + llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { + if (!OMPBuilder.Config.isTargetDevice()) { + RetVal = cast(&Arg); + return CodeGenIP; + } + + Builder.restoreIP(AllocaIP); + + llvm::Value *Addr = Builder.CreateAlloca( + Arg.getType()->isPointerTy() + ? 
Arg.getType() + : Type::getInt64Ty(Builder.getContext()), + OMPBuilder.M.getDataLayout().getAllocaAddrSpace()); + llvm::Value *AddrAscast = + Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType()); + Builder.CreateStore(&Arg, AddrAscast); + + Builder.restoreIP(CodeGenIP); + + RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast); + + return Builder.saveIP(); + }; + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { @@ -5245,9 +5272,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); - Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(), - Builder.saveIP(), EntryInfo, -1, 0, - Inputs, GenMapInfoCB, BodyGenCB)); + Builder.restoreIP(OMPBuilder.createTarget( + OmpLoc, Builder.saveIP(), Builder.saveIP(), EntryInfo, -1, 0, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); OMPBuilder.finalize(); Builder.CreateRetVoid(); @@ -5301,6 +5328,33 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { Constant::getNullValue(PointerType::get(Ctx, 0)), Constant::getNullValue(PointerType::get(Ctx, 0))}; + auto SimpleArgAccessorCB = + [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal, + llvm::OpenMPIRBuilder::InsertPointTy AllocaIP, + llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { + if (!OMPBuilder.Config.isTargetDevice()) { + RetVal = cast(&Arg); + return CodeGenIP; + } + + Builder.restoreIP(AllocaIP); + + llvm::Value *Addr = Builder.CreateAlloca( + Arg.getType()->isPointerTy() + ? 
Arg.getType() + : Type::getInt64Ty(Builder.getContext()), + OMPBuilder.M.getDataLayout().getAllocaAddrSpace()); + llvm::Value *AddrAscast = + Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType()); + Builder.CreateStore(&Arg, AddrAscast); + + Builder.restoreIP(CodeGenIP); + + RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast); + + return Builder.saveIP(); + }; + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { @@ -5322,9 +5376,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, /*Line=*/3, /*Count=*/0); - Builder.restoreIP(OMPBuilder.createTarget( - Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1, - /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, BodyGenCB)); + Builder.restoreIP( + OMPBuilder.createTarget(Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1, + /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB)); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -5343,10 +5398,18 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { // Check entry block auto &EntryBlock = OutlinedFn->getEntryBlock(); - Instruction *Init = EntryBlock.getFirstNonPHI(); - EXPECT_NE(Init, nullptr); + Instruction *Alloca1 = EntryBlock.getFirstNonPHI(); + EXPECT_NE(Alloca1, nullptr); + + EXPECT_TRUE(isa(Alloca1)); + auto *Store1 = Alloca1->getNextNode(); + EXPECT_TRUE(isa(Store1)); + auto *Alloca2 = Store1->getNextNode(); + EXPECT_TRUE(isa(Alloca2)); + auto *Store2 = Alloca2->getNextNode(); + EXPECT_TRUE(isa(Store2)); - auto *InitCall = dyn_cast(Init); + auto *InitCall = dyn_cast(Store2->getNextNode()); EXPECT_NE(InitCall, nullptr); EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init"); EXPECT_EQ(InitCall->arg_size(), 1U); @@ -5370,17 +5433,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { // Check user code block auto *UserCodeBlock = 
EntryBlockBranch->getSuccessor(0); EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry"); - auto *Alloca1 = UserCodeBlock->getFirstNonPHI(); - EXPECT_TRUE(isa(Alloca1)); - auto *Store1 = Alloca1->getNextNode(); - EXPECT_TRUE(isa(Store1)); - auto *Load1 = Store1->getNextNode(); + auto *Load1 = UserCodeBlock->getFirstNonPHI(); EXPECT_TRUE(isa(Load1)); - auto *Alloca2 = Load1->getNextNode(); - EXPECT_TRUE(isa(Alloca2)); - auto *Store2 = Alloca2->getNextNode(); - EXPECT_TRUE(isa(Store2)); - auto *Load2 = Store2->getNextNode(); + auto *Load2 = Load1->getNextNode(); EXPECT_TRUE(isa(Load2)); auto *Value1 = Load2->getNextNode(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 27e22c3cef430..1ec3bb8e7562a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2018,6 +2018,31 @@ handleDeclareTargetMapVar(llvm::ArrayRef mapOperands, } } +static llvm::IRBuilderBase::InsertPoint +createDeviceArgumentAccessor(llvm::Argument &arg, llvm::Value *input, + llvm::Value *&retVal, llvm::IRBuilderBase &builder, + llvm::OpenMPIRBuilder &ompBuilder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase::InsertPoint allocaIP, + llvm::IRBuilderBase::InsertPoint codeGenIP) { + builder.restoreIP(allocaIP); + + llvm::Value *addr = + builder.CreateAlloca(arg.getType()->isPointerTy() + ? 
arg.getType() + : llvm::Type::getInt64Ty(builder.getContext()), + ompBuilder.M.getDataLayout().getAllocaAddrSpace()); + llvm::Value *addrAscast = + builder.CreatePointerBitCastOrAddrSpaceCast(addr, input->getType()); + builder.CreateStore(&arg, addrAscast); + + builder.restoreIP(codeGenIP); + + retVal = builder.CreateLoad(arg.getType(), addrAscast); + + return builder.saveIP(); +} + static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -2109,9 +2134,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, return combinedInfos; }; + auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input, + llvm::Value *&retVal, InsertPointTy allocaIP, + InsertPointTy codeGenIP) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + + // We just return the unaltered argument for the host function + // for now, some alterations may be required in the future to + // keep host fallback functions working identically to the device + // version (e.g. 
pass ByCopy values should be treated as such on + // host and device, currently not always the case) + if (!ompBuilder->Config.isTargetDevice()) { + retVal = cast(&arg); + return codeGenIP; + } + + return createDeviceArgumentAccessor(arg, input, retVal, builder, + *ompBuilder, moduleTranslation, + allocaIP, codeGenIP); + }; + builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget( ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams, - defaultValThreads, inputs, genMapInfoCB, bodyCB)); + defaultValThreads, inputs, genMapInfoCB, bodyCB, argAccessorCB)); // Remap access operations to declare target reference pointers for the // device, essentially generating extra loadop's as necessary diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir index cf70469e7484f..99f1a3b072ad8 100644 --- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir @@ -31,18 +31,18 @@ module attributes {omp.is_target_device = true} { // CHECK: @[[DYNA_ENV:.*]] = weak_odr protected global %struct.DynamicEnvironmentTy zeroinitializer // CHECK: @[[KERNEL_ENV:.*]] = weak_odr protected constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1 }, ptr @[[IDENT]], ptr @[[DYNA_ENV]] } // CHECK: define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region__l{{[0-9]+}}(ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]]) -// CHECK: %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[KERNEL_ENV]]) -// CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %3, -1 -// CHECK-NEXT: br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]] -// CHECK: [[LABEL_ENTRY]]: // CHECK: %[[TMP_A:.*]] = alloca ptr, align 8 // CHECK: store ptr %[[ADDR_A]], ptr %[[TMP_A]], align 8 -// CHECK: %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8 // CHECK: %[[TMP_B:.*]] = alloca ptr, align 8 // CHECK: 
store ptr %[[ADDR_B]], ptr %[[TMP_B]], align 8 -// CHECK: %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8 // CHECK: %[[TMP_C:.*]] = alloca ptr, align 8 // CHECK: store ptr %[[ADDR_C]], ptr %[[TMP_C]], align 8 +// CHECK: %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[KERNEL_ENV]]) +// CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %[[INIT]], -1 +// CHECK-NEXT: br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]] +// CHECK: [[LABEL_ENTRY]]: +// CHECK: %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8 +// CHECK: %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8 // CHECK: %[[PTR_C:.*]] = load ptr, ptr %[[TMP_C]], align 8 // CHECK-NEXT: br label %[[LABEL_TARGET:.*]] // CHECK: [[LABEL_TARGET]]: