From 8f88cb74ef00e4f86e815419343362a676224521 Mon Sep 17 00:00:00 2001 From: Andrew Gozillon Date: Fri, 6 Oct 2023 08:58:27 -0500 Subject: [PATCH] [OpenMP][OpenMPIRBuilder] Move copyInput to a passed in lambda function and re-order kernel argument load/stores This patch moves the existing copyInput function into a lambda argument that can be defined by a caller to the function. This allows more flexibility in how the function is defined, allowing Clang and MLIR to utilise their own respective functions and types inside of the lamba without affecting the OMPIRBuilder itself. The idea is to eventually replace/build on the existing copyInput function that's used and moved into OpenMPToLLVMIRTranslation.cpp to a slightly more complex implementation that uses MLIRs map information (primarily ByRef and ByCapture information at the moment). The patch also moves kernel load stores to the top of the kernel, prior to the first openmp runtime invocation. Just makes the IR a little closer to Clang. --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 9 +- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 66 ++++++------- .../Frontend/OpenMPIRBuilderTest.cpp | 93 +++++++++++++++---- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 47 +++++++++- .../LLVMIR/omptarget-region-device-llvm.mlir | 12 +-- 5 files changed, 163 insertions(+), 64 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 75da461cfd8d9..b505daae7e75f 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2166,6 +2166,10 @@ class OpenMPIRBuilder { using TargetBodyGenCallbackTy = function_ref; + using TargetGenArgAccessorsCallbackTy = function_ref; + /// Generator for '#omp target' /// /// \param Loc where the target data construct was encountered. @@ -2177,6 +2181,8 @@ class OpenMPIRBuilder { /// \param Inputs The input values to the region that will be passed. /// as arguments to the outlined function. /// \param BodyGenCB Callback that will generate the region code. + /// \param ArgAccessorFuncCB Callback that will generate accessors + /// instructions for passed in target arguments where neccessary InsertPointTy createTarget(const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP, @@ -2184,7 +2190,8 @@ class OpenMPIRBuilder { int32_t NumThreads, SmallVectorImpl &Inputs, GenMapInfoCallbackTy GenMapInfoCB, - TargetBodyGenCallbackTy BodyGenCB); + TargetBodyGenCallbackTy BodyGenCB, + TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB); /// Returns __kmpc_for_static_init_* runtime function for the specified /// size \a IVSize and sign \a IVSigned. Will create a distribute call diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 72e1af55fe63f..c95dbbe996660 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -4539,25 +4539,11 @@ FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize, return getOrCreateRuntimeFunction(M, Name); } -// Copy input from pointer or i64 to the expected argument type. -static Value *copyInput(IRBuilderBase &Builder, unsigned AddrSpace, - Value *Input, Argument &Arg) { - auto Addr = Builder.CreateAlloca(Arg.getType()->isPointerTy() - ? Arg.getType() - : Type::getInt64Ty(Builder.getContext()), - AddrSpace); - auto AddrAscast = - Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType()); - Builder.CreateStore(&Arg, AddrAscast); - auto Copy = Builder.CreateLoad(Arg.getType(), AddrAscast); - - return Copy; -} - -static Function * -createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, - StringRef FuncName, SmallVectorImpl &Inputs, - OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { +static Function *createOutlinedFunction( + OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName, + SmallVectorImpl &Inputs, + OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc, + OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) { SmallVector ParameterTypes; if (OMPBuilder.Config.isTargetDevice()) { // All parameters to target devices are passed as pointers @@ -4597,18 +4583,20 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, // Insert return instruction. Builder.CreateRetVoid(); - // Rewrite uses of input valus to parameters. + // New Alloca IP at entry point of created device function. + Builder.SetInsertPoint(EntryBB->getFirstNonPHI()); + auto AllocaIP = Builder.saveIP(); + Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg()); + + // Rewrite uses of input valus to parameters. for (auto InArg : zip(Inputs, Func->args())) { Value *Input = std::get<0>(InArg); Argument &Arg = std::get<1>(InArg); + Value *InputCopy = nullptr; - Value *InputCopy = - OMPBuilder.Config.isTargetDevice() - ? copyInput(Builder, - OMPBuilder.M.getDataLayout().getAllocaAddrSpace(), - Input, Arg) - : &Arg; + Builder.restoreIP( + ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP())); // Collect all the instructions for (User *User : make_early_inc_range(Input->users())) @@ -4623,18 +4611,19 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, return Func; } -static void -emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, - TargetRegionEntryInfo &EntryInfo, - Function *&OutlinedFn, Constant *&OutlinedFnID, - int32_t NumTeams, int32_t NumThreads, - SmallVectorImpl &Inputs, - OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { +static void emitTargetOutlinedFunction( + OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn, + Constant *&OutlinedFnID, int32_t NumTeams, int32_t NumThreads, + SmallVectorImpl &Inputs, + OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc, + OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) { OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction = - [&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) { + [&OMPBuilder, &Builder, &Inputs, &CBFunc, + &ArgAccessorFuncCB](StringRef EntryFnName) { return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs, - CBFunc); + CBFunc, ArgAccessorFuncCB); }; OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, @@ -4698,7 +4687,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams, int32_t NumThreads, SmallVectorImpl &Args, - GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy CBFunc) { + GenMapInfoCallbackTy GenMapInfoCB, + OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, + OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -4707,7 +4698,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( Function *OutlinedFn; Constant *OutlinedFnID; emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, - OutlinedFnID, NumTeams, NumThreads, Args, CBFunc); + OutlinedFnID, NumTeams, NumThreads, Args, CBFunc, + ArgAccessorFuncCB); if (!Config.isTargetDevice()) emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams, NumThreads, Args, GenMapInfoCB); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 4704d8cb492b7..5de9a7073604a 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5236,6 +5236,33 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { Inputs.push_back(BPtr); Inputs.push_back(CPtr); + auto SimpleArgAccessorCB = + [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal, + llvm::OpenMPIRBuilder::InsertPointTy AllocaIP, + llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { + if (!OMPBuilder.Config.isTargetDevice()) { + RetVal = cast(&Arg); + return CodeGenIP; + } + + Builder.restoreIP(AllocaIP); + + llvm::Value *Addr = Builder.CreateAlloca( + Arg.getType()->isPointerTy() + ? Arg.getType() + : Type::getInt64Ty(Builder.getContext()), + OMPBuilder.M.getDataLayout().getAllocaAddrSpace()); + llvm::Value *AddrAscast = + Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType()); + Builder.CreateStore(&Arg, AddrAscast); + + Builder.restoreIP(CodeGenIP); + + RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast); + + return Builder.saveIP(); + }; + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { @@ -5245,9 +5272,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); - Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(), - Builder.saveIP(), EntryInfo, -1, 0, - Inputs, GenMapInfoCB, BodyGenCB)); + Builder.restoreIP(OMPBuilder.createTarget( + OmpLoc, Builder.saveIP(), Builder.saveIP(), EntryInfo, -1, 0, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); OMPBuilder.finalize(); Builder.CreateRetVoid(); @@ -5301,6 +5328,33 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { Constant::getNullValue(PointerType::get(Ctx, 0)), Constant::getNullValue(PointerType::get(Ctx, 0))}; + auto SimpleArgAccessorCB = + [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal, + llvm::OpenMPIRBuilder::InsertPointTy AllocaIP, + llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { + if (!OMPBuilder.Config.isTargetDevice()) { + RetVal = cast(&Arg); + return CodeGenIP; + } + + Builder.restoreIP(AllocaIP); + + llvm::Value *Addr = Builder.CreateAlloca( + Arg.getType()->isPointerTy() + ? Arg.getType() + : Type::getInt64Ty(Builder.getContext()), + OMPBuilder.M.getDataLayout().getAllocaAddrSpace()); + llvm::Value *AddrAscast = + Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType()); + Builder.CreateStore(&Arg, AddrAscast); + + Builder.restoreIP(CodeGenIP); + + RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast); + + return Builder.saveIP(); + }; + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { @@ -5322,9 +5376,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, /*Line=*/3, /*Count=*/0); - Builder.restoreIP(OMPBuilder.createTarget( - Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1, - /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, BodyGenCB)); + Builder.restoreIP( + OMPBuilder.createTarget(Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1, + /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB)); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -5343,10 +5398,18 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { // Check entry block auto &EntryBlock = OutlinedFn->getEntryBlock(); - Instruction *Init = EntryBlock.getFirstNonPHI(); - EXPECT_NE(Init, nullptr); + Instruction *Alloca1 = EntryBlock.getFirstNonPHI(); + EXPECT_NE(Alloca1, nullptr); + + EXPECT_TRUE(isa(Alloca1)); + auto *Store1 = Alloca1->getNextNode(); + EXPECT_TRUE(isa(Store1)); + auto *Alloca2 = Store1->getNextNode(); + EXPECT_TRUE(isa(Alloca2)); + auto *Store2 = Alloca2->getNextNode(); + EXPECT_TRUE(isa(Store2)); - auto *InitCall = dyn_cast(Init); + auto *InitCall = dyn_cast(Store2->getNextNode()); EXPECT_NE(InitCall, nullptr); EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init"); EXPECT_EQ(InitCall->arg_size(), 1U); @@ -5370,17 +5433,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { // Check user code block auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0); EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry"); - auto *Alloca1 = UserCodeBlock->getFirstNonPHI(); - EXPECT_TRUE(isa(Alloca1)); - auto *Store1 = Alloca1->getNextNode(); - EXPECT_TRUE(isa(Store1)); - auto *Load1 = Store1->getNextNode(); + auto *Load1 = UserCodeBlock->getFirstNonPHI(); EXPECT_TRUE(isa(Load1)); - auto *Alloca2 = Load1->getNextNode(); - EXPECT_TRUE(isa(Alloca2)); - auto *Store2 = Alloca2->getNextNode(); - EXPECT_TRUE(isa(Store2)); - auto *Load2 = Store2->getNextNode(); + auto *Load2 = Load1->getNextNode(); EXPECT_TRUE(isa(Load2)); auto *Value1 = Load2->getNextNode(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 27e22c3cef430..1ec3bb8e7562a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2018,6 +2018,31 @@ handleDeclareTargetMapVar(llvm::ArrayRef mapOperands, } } +static llvm::IRBuilderBase::InsertPoint +createDeviceArgumentAccessor(llvm::Argument &arg, llvm::Value *input, + llvm::Value *&retVal, llvm::IRBuilderBase &builder, + llvm::OpenMPIRBuilder &ompBuilder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase::InsertPoint allocaIP, + llvm::IRBuilderBase::InsertPoint codeGenIP) { + builder.restoreIP(allocaIP); + + llvm::Value *addr = + builder.CreateAlloca(arg.getType()->isPointerTy() + ? arg.getType() + : llvm::Type::getInt64Ty(builder.getContext()), + ompBuilder.M.getDataLayout().getAllocaAddrSpace()); + llvm::Value *addrAscast = + builder.CreatePointerBitCastOrAddrSpaceCast(addr, input->getType()); + builder.CreateStore(&arg, addrAscast); + + builder.restoreIP(codeGenIP); + + retVal = builder.CreateLoad(arg.getType(), addrAscast); + + return builder.saveIP(); +} + static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -2109,9 +2134,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, return combinedInfos; }; + auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input, + llvm::Value *&retVal, InsertPointTy allocaIP, + InsertPointTy codeGenIP) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + + // We just return the unaltered argument for the host function + // for now, some alterations may be required in the future to + // keep host fallback functions working identically to the device + // version (e.g. pass ByCopy values should be treated as such on + // host and device, currently not always the case) + if (!ompBuilder->Config.isTargetDevice()) { + retVal = cast(&arg); + return codeGenIP; + } + + return createDeviceArgumentAccessor(arg, input, retVal, builder, + *ompBuilder, moduleTranslation, + allocaIP, codeGenIP); + }; + builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget( ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams, - defaultValThreads, inputs, genMapInfoCB, bodyCB)); + defaultValThreads, inputs, genMapInfoCB, bodyCB, argAccessorCB)); // Remap access operations to declare target reference pointers for the // device, essentially generating extra loadop's as necessary diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir index cf70469e7484f..99f1a3b072ad8 100644 --- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir @@ -31,18 +31,18 @@ module attributes {omp.is_target_device = true} { // CHECK: @[[DYNA_ENV:.*]] = weak_odr protected global %struct.DynamicEnvironmentTy zeroinitializer // CHECK: @[[KERNEL_ENV:.*]] = weak_odr protected constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1 }, ptr @[[IDENT]], ptr @[[DYNA_ENV]] } // CHECK: define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region__l{{[0-9]+}}(ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]]) -// CHECK: %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[KERNEL_ENV]]) -// CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %3, -1 -// CHECK-NEXT: br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]] -// CHECK: [[LABEL_ENTRY]]: // CHECK: %[[TMP_A:.*]] = alloca ptr, align 8 // CHECK: store ptr %[[ADDR_A]], ptr %[[TMP_A]], align 8 -// CHECK: %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8 // CHECK: %[[TMP_B:.*]] = alloca ptr, align 8 // CHECK: store ptr %[[ADDR_B]], ptr %[[TMP_B]], align 8 -// CHECK: %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8 // CHECK: %[[TMP_C:.*]] = alloca ptr, align 8 // CHECK: store ptr %[[ADDR_C]], ptr %[[TMP_C]], align 8 +// CHECK: %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[KERNEL_ENV]]) +// CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %[[INIT]], -1 +// CHECK-NEXT: br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]] +// CHECK: [[LABEL_ENTRY]]: +// CHECK: %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8 +// CHECK: %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8 // CHECK: %[[PTR_C:.*]] = load ptr, ptr %[[TMP_C]], align 8 // CHECK-NEXT: br label %[[LABEL_TARGET:.*]] // CHECK: [[LABEL_TARGET]]: