diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index a4d2524bccf5c..cf91b2638aecc 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -166,15 +166,11 @@ getIfClauseOperand(lower::AbstractConverter &converter, static void addUseDeviceClause( lower::AbstractConverter &converter, const omp::ObjectList &objects, llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) { genObjectList(objects, converter, operands); - for (mlir::Value &operand : operands) { + for (mlir::Value &operand : operands) checkMapType(operand.getLoc(), operand.getType()); - useDeviceTypes.push_back(operand.getType()); - useDeviceLocs.push_back(operand.getLoc()); - } + for (const omp::Object &object : objects) useDeviceSyms.push_back(object.sym()); } @@ -832,14 +828,12 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const { bool ClauseProcessor::processHasDeviceAddr( mlir::omp::HasDeviceAddrClauseOps &result, - llvm::SmallVectorImpl &isDeviceTypes, - llvm::SmallVectorImpl &isDeviceLocs, - llvm::SmallVectorImpl &isDeviceSymbols) const { + llvm::SmallVectorImpl &isDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::HasDeviceAddr &devAddrClause, const parser::CharBlock &) { addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars, - isDeviceTypes, isDeviceLocs, isDeviceSymbols); + isDeviceSyms); }); } @@ -864,14 +858,12 @@ bool ClauseProcessor::processIf( bool ClauseProcessor::processIsDevicePtr( mlir::omp::IsDevicePtrClauseOps &result, - llvm::SmallVectorImpl &isDeviceTypes, - llvm::SmallVectorImpl &isDeviceLocs, - llvm::SmallVectorImpl &isDeviceSymbols) const { + llvm::SmallVectorImpl &isDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::IsDevicePtr &devPtrClause, const parser::CharBlock &) { addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars, - isDeviceTypes, isDeviceLocs, isDeviceSymbols); + isDeviceSyms); }); } @@ -892,9 +884,7 @@ void ClauseProcessor::processMapObjects( std::map> &parentMemberIndices, llvm::SmallVectorImpl &mapVars, - llvm::SmallVectorImpl *mapSyms, - llvm::SmallVectorImpl *mapSymLocs, - llvm::SmallVectorImpl *mapSymTypes) const { + llvm::SmallVectorImpl &mapSyms) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); for (const omp::Object &object : objects) { llvm::SmallVector bounds; @@ -927,12 +917,7 @@ void ClauseProcessor::processMapObjects( addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx); } else { mapVars.push_back(mapOp); - if (mapSyms) - mapSyms->push_back(object.sym()); - if (mapSymTypes) - mapSymTypes->push_back(baseOp.getType()); - if (mapSymLocs) - mapSymLocs->push_back(baseOp.getLoc()); + mapSyms.push_back(object.sym()); } } } @@ -940,9 +925,7 @@ void ClauseProcessor::processMapObjects( bool ClauseProcessor::processMap( mlir::Location currentLocation, lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result, - llvm::SmallVectorImpl *mapSyms, - llvm::SmallVectorImpl *mapSymLocs, - llvm::SmallVectorImpl *mapSymTypes) const { + llvm::SmallVectorImpl *mapSyms) const { // We always require tracking of symbols, even if the caller does not, // so we create an optionally used local set of symbols when the mapSyms // argument is not present. 
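[Editor's illustration, not part of the patch.] The hunks above collapse the parallel type/location/symbol vectors into a single symbol list, and `processMap` keeps its optional `mapSyms` out-parameter by redirecting to a locally owned vector when the caller passes nullptr. The following minimal sketch shows only that "optional out-parameter with a local fallback" pattern; it uses standard-library types and invented names (`Symbol`, `collect`, `process`), not the flang API.

#include <string>
#include <vector>

// Stand-in for const semantics::Symbol *; purely illustrative.
using Symbol = std::string;

// Stand-in for downstream helpers that always need somewhere to record symbols.
static void collect(std::vector<Symbol> &syms) { syms.push_back("x"); }

// Callers that do not care about the symbols pass nullptr; the function still
// tracks them internally by falling back to a locally owned vector, so the
// helpers it calls can take a plain reference.
static bool process(std::vector<Symbol> *mapSyms = nullptr) {
  std::vector<Symbol> localSyms;
  std::vector<Symbol> *ptrSyms = mapSyms ? mapSyms : &localSyms;
  collect(*ptrSyms);
  return !ptrSyms->empty();
}

int main() {
  std::vector<Symbol> syms;
  bool trackedExternally = process(&syms); // caller receives the symbols
  bool trackedLocally = process();         // symbols tracked internally only
  return (trackedExternally && trackedLocally) ? 0 : 1;
}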
@@ -999,12 +982,11 @@ bool ClauseProcessor::processMap( } processMapObjects(stmtCtx, clauseLocation, std::get(clause.t), mapTypeBits, - parentMemberIndices, result.mapVars, ptrMapSyms, - mapSymLocs, mapSymTypes); + parentMemberIndices, result.mapVars, *ptrMapSyms); }); insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars, - *ptrMapSyms, mapSymTypes, mapSymLocs); + *ptrMapSyms); return clauseFound; } @@ -1027,7 +1009,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, processMapObjects(stmtCtx, clauseLocation, std::get(clause.t), mapTypeBits, parentMemberIndices, result.mapVars, - &mapSymbols); + mapSymbols); }; bool clauseFound = findRepeatableClause(callbackFn); @@ -1035,8 +1017,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, findRepeatableClause(callbackFn) || clauseFound; insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars, - mapSymbols, - /*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr); + mapSymbols); return clauseFound; } @@ -1054,8 +1035,7 @@ bool ClauseProcessor::processNontemporal( bool ClauseProcessor::processReduction( mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, - llvm::SmallVectorImpl *outReductionTypes, - llvm::SmallVectorImpl *outReductionSyms) const { + llvm::SmallVectorImpl &outReductionSyms) const { return findRepeatableClause( [&](const omp::clause::Reduction &clause, const parser::CharBlock &) { llvm::SmallVector reductionVars; @@ -1063,25 +1043,16 @@ bool ClauseProcessor::processReduction( llvm::SmallVector reductionDeclSymbols; llvm::SmallVector reductionSyms; ReductionProcessor rp; - rp.addDeclareReduction( - currentLocation, converter, clause, reductionVars, reduceVarByRef, - reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr); + rp.addDeclareReduction(currentLocation, converter, clause, + reductionVars, reduceVarByRef, + reductionDeclSymbols, reductionSyms); // Copy local lists into the output. 
llvm::copy(reductionVars, std::back_inserter(result.reductionVars)); llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref)); llvm::copy(reductionDeclSymbols, std::back_inserter(result.reductionSyms)); - - if (outReductionTypes) { - outReductionTypes->reserve(outReductionTypes->size() + - reductionVars.size()); - llvm::transform(reductionVars, std::back_inserter(*outReductionTypes), - [](mlir::Value v) { return v.getType(); }); - } - - if (outReductionSyms) - llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms)); + llvm::copy(reductionSyms, std::back_inserter(outReductionSyms)); }); } @@ -1107,8 +1078,6 @@ bool ClauseProcessor::processEnter( bool ClauseProcessor::processUseDeviceAddr( lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) const { std::map> @@ -1122,19 +1091,16 @@ bool ClauseProcessor::processUseDeviceAddr( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDeviceAddrVars, - &useDeviceSyms, &useDeviceLocs, &useDeviceTypes); + useDeviceSyms); }); insertChildMapInfoIntoParent(converter, parentMemberIndices, - result.useDeviceAddrVars, useDeviceSyms, - &useDeviceTypes, &useDeviceLocs); + result.useDeviceAddrVars, useDeviceSyms); return clauseFound; } bool ClauseProcessor::processUseDevicePtr( lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) const { std::map> @@ -1148,12 +1114,11 @@ bool ClauseProcessor::processUseDevicePtr( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDevicePtrVars, - &useDeviceSyms, &useDeviceLocs, &useDeviceTypes); + useDeviceSyms); }); insertChildMapInfoIntoParent(converter, parentMemberIndices, - result.useDevicePtrVars, useDeviceSyms, - &useDeviceTypes, &useDeviceLocs); + result.useDevicePtrVars, useDeviceSyms); return clauseFound; } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 0c8e7bd47ab5a..f34121c70d0b4 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -68,9 +68,7 @@ class ClauseProcessor { mlir::omp::FinalClauseOps &result) const; bool processHasDeviceAddr( mlir::omp::HasDeviceAddrClauseOps &result, - llvm::SmallVectorImpl &isDeviceTypes, - llvm::SmallVectorImpl &isDeviceLocs, - llvm::SmallVectorImpl &isDeviceSymbols) const; + llvm::SmallVectorImpl &isDeviceSyms) const; bool processHint(mlir::omp::HintClauseOps &result) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; @@ -104,43 +102,33 @@ class ClauseProcessor { mlir::omp::IfClauseOps &result) const; bool processIsDevicePtr( mlir::omp::IsDevicePtrClauseOps &result, - llvm::SmallVectorImpl &isDeviceTypes, - llvm::SmallVectorImpl &isDeviceLocs, - llvm::SmallVectorImpl &isDeviceSymbols) const; + llvm::SmallVectorImpl &isDeviceSyms) const; bool processLink(llvm::SmallVectorImpl &result) const; // This method is used to process a map clause. - // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to - // store the original type, location and Fortran symbol for the map operands. 
- // They may be used later on to create the block_arguments for some of the - // target directives that require it. - bool processMap( - mlir::Location currentLocation, lower::StatementContext &stmtCtx, - mlir::omp::MapClauseOps &result, - llvm::SmallVectorImpl *mapSyms = nullptr, - llvm::SmallVectorImpl *mapSymLocs = nullptr, - llvm::SmallVectorImpl *mapSymTypes = nullptr) const; + // The optional parameter mapSyms is used to store the original Fortran symbol + // for the map operands. It may be used later on to create the block_arguments + // for some of the directives that require it. + bool processMap(mlir::Location currentLocation, + lower::StatementContext &stmtCtx, + mlir::omp::MapClauseOps &result, + llvm::SmallVectorImpl *mapSyms = + nullptr) const; bool processMotionClauses(lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result); bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const; bool processReduction( mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, - llvm::SmallVectorImpl *reductionTypes = nullptr, - llvm::SmallVectorImpl *reductionSyms = - nullptr) const; + llvm::SmallVectorImpl &reductionSyms) const; bool processTo(llvm::SmallVectorImpl &result) const; bool processUseDeviceAddr( lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) const; bool processUseDevicePtr( lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) const; // Call this method for these clauses that should be supported but are not @@ -181,9 +169,7 @@ class ClauseProcessor { std::map> &parentMemberIndices, llvm::SmallVectorImpl &mapVars, - llvm::SmallVectorImpl *mapSyms, - llvm::SmallVectorImpl *mapSymLocs = nullptr, - llvm::SmallVectorImpl *mapSymTypes = nullptr) const; + llvm::SmallVectorImpl &mapSyms) const; lower::AbstractConverter &converter; semantics::SemanticsContext &semaCtx; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 8195f4a897a90..b1a10960c8022 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -45,6 +45,40 @@ using namespace Fortran::lower::omp; // Code generation helper functions //===----------------------------------------------------------------------===// +namespace { +/// Structure holding the information needed to create and bind entry block +/// arguments associated to a single clause. +struct EntryBlockArgsEntry { + llvm::ArrayRef syms; + llvm::ArrayRef vars; + + bool isValid() const { + // This check allows specifying a smaller number of symbols than values + // because in some case cases a single symbol generates multiple block + // arguments. + return syms.size() <= vars.size(); + } +}; + +/// Structure holding the information needed to create and bind entry block +/// arguments associated to all clauses that can define them. 
+struct EntryBlockArgs { + EntryBlockArgsEntry inReduction; + EntryBlockArgsEntry map; + EntryBlockArgsEntry priv; + EntryBlockArgsEntry reduction; + EntryBlockArgsEntry taskReduction; + EntryBlockArgsEntry useDeviceAddr; + EntryBlockArgsEntry useDevicePtr; + + bool isValid() const { + return inReduction.isValid() && map.isValid() && priv.isValid() && + reduction.isValid() && taskReduction.isValid() && + useDeviceAddr.isValid() && useDevicePtr.isValid(); + } +}; +} // namespace + static void genOMPDispatch(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, @@ -52,6 +86,164 @@ static void genOMPDispatch(lower::AbstractConverter &converter, const ConstructQueue &queue, ConstructQueue::const_iterator item); +/// Bind symbols to their corresponding entry block arguments. +/// +/// The binding will be performed inside of the current block, which does not +/// necessarily have to be part of the operation for which the binding is done. +/// However, block arguments must be accessible. This enables controlling the +/// insertion point of any new MLIR operations related to the binding of +/// arguments of a loop wrapper operation. +/// +/// \param [in] converter - PFT to MLIR conversion interface. +/// \param [in] op - owner operation of the block arguments to bind. +/// \param [in] args - entry block arguments information for the given +/// operation. +static void bindEntryBlockArgs(lower::AbstractConverter &converter, + mlir::omp::BlockArgOpenMPOpInterface op, + const EntryBlockArgs &args) { + assert(op != nullptr && "invalid block argument-defining operation"); + assert(args.isValid() && "invalid args"); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + auto bindSingleMapLike = [&converter, + &firOpBuilder](const semantics::Symbol &sym, + const mlir::BlockArgument &arg) { + // Clones the `bounds` placing them inside the entry block and returns + // them. 
+ auto cloneBound = [&](mlir::Value bound) { + if (mlir::isMemoryEffectFree(bound.getDefiningOp())) { + mlir::Operation *clonedOp = firOpBuilder.clone(*bound.getDefiningOp()); + return clonedOp->getResult(0); + } + TODO(converter.getCurrentLocation(), + "target map-like clause operand unsupported bound type"); + }; + + auto cloneBounds = [cloneBound](llvm::ArrayRef bounds) { + llvm::SmallVector clonedBounds; + llvm::transform(bounds, std::back_inserter(clonedBounds), + [&](mlir::Value bound) { return cloneBound(bound); }); + return clonedBounds; + }; + + fir::ExtendedValue extVal = converter.getSymbolExtendedValue(sym); + auto refType = mlir::dyn_cast(arg.getType()); + if (refType && fir::isa_builtin_cptr_type(refType.getElementType())) { + converter.bindSymbol(sym, arg); + } else { + extVal.match( + [&](const fir::BoxValue &v) { + converter.bindSymbol(sym, + fir::BoxValue(arg, cloneBounds(v.getLBounds()), + v.getExplicitParameters(), + v.getExplicitExtents())); + }, + [&](const fir::MutableBoxValue &v) { + converter.bindSymbol( + sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()), + v.getMutableProperties())); + }, + [&](const fir::ArrayBoxValue &v) { + converter.bindSymbol( + sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()), + cloneBounds(v.getLBounds()), + v.getSourceBox())); + }, + [&](const fir::CharArrayBoxValue &v) { + converter.bindSymbol( + sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()), + cloneBounds(v.getExtents()), + cloneBounds(v.getLBounds()))); + }, + [&](const fir::CharBoxValue &v) { + converter.bindSymbol( + sym, fir::CharBoxValue(arg, cloneBound(v.getLen()))); + }, + [&](const fir::UnboxedValue &v) { converter.bindSymbol(sym, arg); }, + [&](const auto &) { + TODO(converter.getCurrentLocation(), + "target map clause operand unsupported type"); + }); + } + }; + + auto bindMapLike = + [&bindSingleMapLike](llvm::ArrayRef syms, + llvm::ArrayRef args) { + // Structure component symbols don't have bindings, and can only be + // explicitly mapped individually. If a member is captured implicitly + // we map the entirety of the derived type when we find its symbol. + llvm::SmallVector processedSyms; + llvm::copy_if(syms, std::back_inserter(processedSyms), + [](auto *sym) { return !sym->owner().IsDerivedType(); }); + + for (auto [sym, arg] : llvm::zip_equal(processedSyms, args)) + bindSingleMapLike(*sym, arg); + }; + + auto bindPrivateLike = [&converter, &firOpBuilder]( + llvm::ArrayRef syms, + llvm::ArrayRef vars, + llvm::ArrayRef args) { + llvm::SmallVector processedSyms; + for (auto *sym : syms) { + if (const auto *commonDet = + sym->detailsIf()) { + llvm::transform(commonDet->objects(), std::back_inserter(processedSyms), + [&](const auto &mem) { return &*mem; }); + } else { + processedSyms.push_back(sym); + } + } + + for (auto [sym, var, arg] : llvm::zip_equal(processedSyms, vars, args)) + converter.bindSymbol( + *sym, + hlfir::translateToExtendedValue( + var.getLoc(), firOpBuilder, hlfir::Entity{arg}, + /*contiguousHint=*/ + evaluate::IsSimplyContiguous(*sym, converter.getFoldingContext())) + .first); + }; + + // Process in clause name alphabetical order to match block arguments order. 
+ bindPrivateLike(args.inReduction.syms, args.inReduction.vars, + op.getInReductionBlockArgs()); + bindMapLike(args.map.syms, op.getMapBlockArgs()); + bindPrivateLike(args.priv.syms, args.priv.vars, op.getPrivateBlockArgs()); + bindPrivateLike(args.reduction.syms, args.reduction.vars, + op.getReductionBlockArgs()); + bindPrivateLike(args.taskReduction.syms, args.taskReduction.vars, + op.getTaskReductionBlockArgs()); + bindMapLike(args.useDeviceAddr.syms, op.getUseDeviceAddrBlockArgs()); + bindMapLike(args.useDevicePtr.syms, op.getUseDevicePtrBlockArgs()); +} + +/// Get the list of base values that the specified map-like variables point to. +/// +/// This function must be kept in sync with changes to the `createMapInfoOp` +/// utility function, since it must take into account the potential introduction +/// of levels of indirection (i.e. intermediate ops). +/// +/// \param [in] vars - list of values passed to map-like clauses, returned +/// by an `omp.map.info` operation. +/// \param [out] baseOps - populated with the `var_ptr` values of the +/// corresponding defining operations. +static void +extractMappedBaseValues(llvm::ArrayRef vars, + llvm::SmallVectorImpl &baseOps) { + llvm::transform(vars, std::back_inserter(baseOps), [](mlir::Value map) { + auto mapInfo = map.getDefiningOp(); + assert(mapInfo && "expected all map vars to be defined by omp.map.info"); + + mlir::Value varPtr = mapInfo.getVarPtr(); + if (auto boxAddr = varPtr.getDefiningOp()) + return boxAddr.getVal(); + + return varPtr; + }); +} + static lower::pft::Evaluation * getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -226,55 +418,41 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter, return storeOp; } -// This helper function implements the functionality of "promoting" -// non-CPTR arguments of use_device_ptr to use_device_addr -// arguments (automagic conversion of use_device_ptr -> -// use_device_addr in these cases). The way we do so currently is -// through the shuffling of operands from the devicePtrOperands to -// deviceAddrOperands where neccesary and re-organizing the types, -// locations and symbols to maintain the correct ordering of ptr/addr -// input -> BlockArg. +// This helper function implements the functionality of "promoting" non-CPTR +// arguments of use_device_ptr to use_device_addr arguments (automagic +// conversion of use_device_ptr -> use_device_addr in these cases). The way we +// do so currently is through the shuffling of operands from the +// devicePtrOperands to deviceAddrOperands, as well as the types, locations and +// symbols. // -// This effectively implements some deprecated OpenMP functionality -// that some legacy applications unfortunately depend on -// (deprecated in specification version 5.2): +// This effectively implements some deprecated OpenMP functionality that some +// legacy applications unfortunately depend on (deprecated in specification +// version 5.2): // -// "If a list item in a use_device_ptr clause is not of type C_PTR, -// the behavior is as if the list item appeared in a use_device_addr -// clause. Support for such list items in a use_device_ptr clause -// is deprecated." +// "If a list item in a use_device_ptr clause is not of type C_PTR, the behavior +// is as if the list item appeared in a use_device_addr clause. Support for +// such list items in a use_device_ptr clause is deprecated." 
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( llvm::SmallVectorImpl &useDeviceAddrVars, + llvm::SmallVectorImpl &useDeviceAddrSyms, llvm::SmallVectorImpl &useDevicePtrVars, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSymbols) { - auto moveElementToBack = [](size_t idx, auto &vector) { - auto *iter = std::next(vector.begin(), idx); - vector.push_back(*iter); - vector.erase(iter); - }; - + llvm::SmallVectorImpl &useDevicePtrSyms) { // Iterate over our use_device_ptr list and shift all non-cptr arguments into // use_device_addr. - for (auto *it = useDevicePtrVars.begin(); it != useDevicePtrVars.end();) { - if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { - useDeviceAddrVars.push_back(*it); - // We have to shuffle the symbols around as well, to maintain - // the correct Input -> BlockArg for use_device_ptr/use_device_addr. - // NOTE: However, as map's do not seem to be included currently - // this isn't as pertinent, but we must try to maintain for - // future alterations. I believe the reason they are not currently - // is that the BlockArg assign/lowering needs to be extended - // to a greater set of types. - auto idx = std::distance(useDevicePtrVars.begin(), it); - moveElementToBack(idx, useDeviceTypes); - moveElementToBack(idx, useDeviceLocs); - moveElementToBack(idx, useDeviceSymbols); - it = useDevicePtrVars.erase(it); + auto *varIt = useDevicePtrVars.begin(); + auto *symIt = useDevicePtrSyms.begin(); + while (varIt != useDevicePtrVars.end()) { + if (fir::isa_builtin_cptr_type(fir::unwrapRefType(varIt->getType()))) { + ++varIt; + ++symIt; continue; } - ++it; + + useDeviceAddrVars.push_back(*varIt); + useDeviceAddrSyms.push_back(*symIt); + + varIt = useDevicePtrVars.erase(varIt); + symIt = useDevicePtrSyms.erase(symIt); } } @@ -380,14 +558,14 @@ getDeclareTargetFunctionDevice( /// \param [in] converter - PFT to MLIR conversion interface. /// \param [in] loc - location. /// \param [in] args - symbols of induction variables. -/// \param [in] wrapperSyms - symbols of variables to be mapped to loop wrapper +/// \param [in] wrapperArgs - list of parent loop wrappers and their associated /// entry block arguments. -/// \param [in] wrapperArgs - entry block arguments of parent loop wrappers. -static void -genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, - mlir::Location &loc, llvm::ArrayRef args, - llvm::ArrayRef wrapperSyms = {}, - llvm::ArrayRef wrapperArgs = {}) { +static void genLoopVars( + mlir::Operation *op, lower::AbstractConverter &converter, + mlir::Location &loc, llvm::ArrayRef args, + llvm::ArrayRef< + std::pair> + wrapperArgs = {}) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); auto ®ion = op->getRegion(0); @@ -401,8 +579,8 @@ genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, // Bind the entry block arguments of parent wrappers to the corresponding // symbols. - for (auto [arg, prv] : llvm::zip_equal(wrapperSyms, wrapperArgs)) - converter.bindSymbol(*arg, prv); + for (auto [argGeneratingOp, args] : wrapperArgs) + bindEntryBlockArgs(converter, argGeneratingOp, args); // The argument is not currently in memory, so make a temporary for the // argument, and store it there, then bind that location to the argument. 
@@ -415,22 +593,47 @@ genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, firOpBuilder.setInsertionPointAfter(storeOp); } -static void -genReductionVars(mlir::Operation *op, lower::AbstractConverter &converter, - mlir::Location &loc, - llvm::ArrayRef reductionArgs, - llvm::ArrayRef reductionTypes) { +/// Create an entry block for the given region, including the clause-defined +/// arguments specified. +/// +/// \param [in] converter - PFT to MLIR conversion interface. +/// \param [in] args - entry block arguments information for the given +/// operation. +/// \param [in] region - Empty region in which to create the entry block. +static mlir::Block *genEntryBlock(lower::AbstractConverter &converter, + const EntryBlockArgs &args, + mlir::Region ®ion) { + assert(args.isValid() && "invalid args"); + assert(region.empty() && "non-empty region"); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - llvm::SmallVector blockArgLocs(reductionArgs.size(), loc); - mlir::Block *entryBlock = firOpBuilder.createBlock( - &op->getRegion(0), {}, reductionTypes, blockArgLocs); + llvm::SmallVector types; + llvm::SmallVector locs; + unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() + + args.priv.vars.size() + args.reduction.vars.size() + + args.taskReduction.vars.size() + + args.useDeviceAddr.vars.size(); + types.reserve(numVars); + locs.reserve(numVars); + + auto extractTypeLoc = [&types, &locs](llvm::ArrayRef vals) { + llvm::transform(vals, std::back_inserter(types), + [](mlir::Value v) { return v.getType(); }); + llvm::transform(vals, std::back_inserter(locs), + [](mlir::Value v) { return v.getLoc(); }); + }; - // Bind the reduction arguments to their block arguments. - for (auto [arg, prv] : - llvm::zip_equal(reductionArgs, entryBlock->getArguments())) { - converter.bindSymbol(*arg, prv); - } + // Populate block arguments in clause name alphabetical order to match + // expected order by the BlockArgOpenMPOpInterface. + extractTypeLoc(args.inReduction.vars); + extractTypeLoc(args.map.vars); + extractTypeLoc(args.priv.vars); + extractTypeLoc(args.reduction.vars); + extractTypeLoc(args.taskReduction.vars); + extractTypeLoc(args.useDeviceAddr.vars); + extractTypeLoc(args.useDevicePtr.vars); + + return firOpBuilder.createBlock(®ion, {}, types, locs); } static void @@ -458,42 +661,6 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter, declareTargetOp.setDeclareTarget(deviceType, captureClause); } -/// For an operation that takes `omp.private` values as region args, this util -/// merges the private vars info into the region arguments list. -/// -/// \tparam OMPOP - the OpenMP op that takes `omp.private` inputs. -/// \tparam InfoTy - the type of private info we want to merge; e.g. mlir::Type -/// or mlir::Location fields of the private var list. -/// -/// \param [in] op - the op accepting `omp.private` inputs. -/// \param [in] currentList - the current list of region info that we -/// want to merge private info with. For example this could be the list of types -/// or locations of previous arguments to \op's region. -/// \param [in] infoAccessor - for a private variable, this returns the -/// data we want to merge: type or location. -/// \param [out] allRegionArgsInfo - the merged list of region info. -/// \param [in] addBeforePrivate - `true` if the passed information goes before -/// private information. 
-template -static void -mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef currentList, - llvm::function_ref infoAccessor, - llvm::SmallVectorImpl &allRegionArgsInfo, - bool addBeforePrivate) { - mlir::OperandRange privateVars = op.getPrivateVars(); - - if (addBeforePrivate) - llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), - [](InfoTy i) { return i; }); - - llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo), - infoAccessor); - - if (!addBeforePrivate) - llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), - [](InfoTy i) { return i; }); -} - //===----------------------------------------------------------------------===// // Op body generation helper structures and functions //===----------------------------------------------------------------------===// @@ -713,94 +880,16 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info, marker->erase(); } -void mapBodySymbols(lower::AbstractConverter &converter, mlir::Region ®ion, - llvm::ArrayRef mapSyms) { - assert(region.hasOneBlock() && "target must have single region"); - mlir::Block ®ionBlock = region.front(); - // Clones the `bounds` placing them inside the target region and returns them. - auto cloneBound = [&](mlir::Value bound) { - if (mlir::isMemoryEffectFree(bound.getDefiningOp())) { - mlir::Operation *clonedOp = bound.getDefiningOp()->clone(); - regionBlock.push_back(clonedOp); - return clonedOp->getResult(0); - } - TODO(converter.getCurrentLocation(), - "target map clause operand unsupported bound type"); - }; - - auto cloneBounds = [cloneBound](llvm::ArrayRef bounds) { - llvm::SmallVector clonedBounds; - for (mlir::Value bound : bounds) - clonedBounds.emplace_back(cloneBound(bound)); - return clonedBounds; - }; - - // Bind the symbols to their corresponding block arguments. - for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { - const mlir::BlockArgument &arg = region.getArgument(argIndex); - // Avoid capture of a reference to a structured binding. - const semantics::Symbol *sym = argSymbol; - // Structure component symbols don't have bindings. 
- if (sym->owner().IsDerivedType()) - continue; - fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym); - auto refType = mlir::dyn_cast(arg.getType()); - if (refType && fir::isa_builtin_cptr_type(refType.getElementType())) { - converter.bindSymbol(*argSymbol, arg); - } else { - extVal.match( - [&](const fir::BoxValue &v) { - converter.bindSymbol(*sym, - fir::BoxValue(arg, cloneBounds(v.getLBounds()), - v.getExplicitParameters(), - v.getExplicitExtents())); - }, - [&](const fir::MutableBoxValue &v) { - converter.bindSymbol( - *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()), - v.getMutableProperties())); - }, - [&](const fir::ArrayBoxValue &v) { - converter.bindSymbol( - *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()), - cloneBounds(v.getLBounds()), - v.getSourceBox())); - }, - [&](const fir::CharArrayBoxValue &v) { - converter.bindSymbol( - *sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()), - cloneBounds(v.getExtents()), - cloneBounds(v.getLBounds()))); - }, - [&](const fir::CharBoxValue &v) { - converter.bindSymbol( - *sym, fir::CharBoxValue(arg, cloneBound(v.getLen()))); - }, - [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); }, - [&](const auto &) { - TODO(converter.getCurrentLocation(), - "target map clause operand unsupported type"); - }); - } - } -} - static void genBodyOfTargetDataOp( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::omp::TargetDataOp &dataOp, - llvm::ArrayRef useDeviceSymbols, - llvm::ArrayRef useDeviceLocs, - llvm::ArrayRef useDeviceTypes, + mlir::omp::TargetDataOp &dataOp, const EntryBlockArgs &args, const mlir::Location ¤tLocation, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - assert(useDeviceTypes.size() == useDeviceLocs.size()); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Region ®ion = dataOp.getRegion(); - firOpBuilder.createBlock(®ion, {}, useDeviceTypes, useDeviceLocs); - mapBodySymbols(converter, region, useDeviceSymbols); + genEntryBlock(converter, args, dataOp.getRegion()); + bindEntryBlockArgs(converter, dataOp, args); // Insert dummy instruction to remember the insertion position. The // marker will be deleted by clean up passes since there are no uses. @@ -841,19 +930,25 @@ static void genBodyOfTargetDataOp( // This is for utilisation with TargetOp. static void genIntermediateCommonBlockAccessors( Fortran::lower::AbstractConverter &converter, - const mlir::Location ¤tLocation, mlir::Region ®ion, + const mlir::Location ¤tLocation, + llvm::ArrayRef mapBlockArgs, llvm::ArrayRef mapSyms) { - for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { - if (auto *details = - argSymbol->detailsIf()) { - for (auto obj : details->objects()) { - auto targetCBMemberBind = Fortran::lower::genCommonBlockMember( - converter, currentLocation, *obj, region.getArgument(argIndex)); - fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*obj); - fir::ExtendedValue targetCBExv = - getExtendedValue(sexv, targetCBMemberBind); - converter.bindSymbol(*obj, targetCBExv); - } + // Iterate over the symbol list, which will be shorter than the list of + // arguments if new entry block arguments were introduced to implicitly map + // outside values used by the bounds cloned into the target region. In that + // case, the additional block arguments do not need processing here. 
+ for (auto [mapSym, mapArg] : llvm::zip_first(mapSyms, mapBlockArgs)) { + auto *details = mapSym->detailsIf(); + if (!details) + continue; + + for (auto obj : details->objects()) { + auto targetCBMemberBind = Fortran::lower::genCommonBlockMember( + converter, currentLocation, *obj, mapArg); + fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*obj); + fir::ExtendedValue targetCBExv = + getExtendedValue(sexv, targetCBMemberBind); + converter.bindSymbol(*obj, targetCBExv); } } } @@ -863,47 +958,15 @@ static void genIntermediateCommonBlockAccessors( static void genBodyOfTargetOp( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::omp::TargetOp &targetOp, - llvm::ArrayRef mapSyms, - llvm::ArrayRef mapSymLocs, - llvm::ArrayRef mapSymTypes, + mlir::omp::TargetOp &targetOp, const EntryBlockArgs &args, const mlir::Location ¤tLocation, const ConstructQueue &queue, ConstructQueue::const_iterator item, DataSharingProcessor &dsp) { - assert(mapSymTypes.size() == mapSymLocs.size()); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Region ®ion = targetOp.getRegion(); + auto argIface = llvm::cast(*targetOp); - llvm::SmallVector allRegionArgTypes; - llvm::SmallVector allRegionArgLocs; - mergePrivateVarsInfo(targetOp, mapSymTypes, - llvm::function_ref{ - [](mlir::Value v) { return v.getType(); }}, - allRegionArgTypes, /*addBeforePrivate=*/true); - - mergePrivateVarsInfo(targetOp, mapSymLocs, - llvm::function_ref{ - [](mlir::Value v) { return v.getLoc(); }}, - allRegionArgLocs, /*addBeforePrivate=*/true); - - mlir::Block *regionBlock = firOpBuilder.createBlock( - ®ion, {}, allRegionArgTypes, allRegionArgLocs); - - mapBodySymbols(converter, region, mapSyms); - - for (auto [argIndex, argSymbol] : - llvm::enumerate(dsp.getAllSymbolsToPrivatize())) { - argIndex = mapSyms.size() + argIndex; - - const mlir::BlockArgument &arg = region.getArgument(argIndex); - converter.bindSymbol(*argSymbol, - hlfir::translateToExtendedValue( - currentLocation, firOpBuilder, hlfir::Entity{arg}, - /*contiguousHint=*/ - evaluate::IsSimplyContiguous( - *argSymbol, converter.getFoldingContext())) - .first); - } + mlir::Region ®ion = targetOp.getRegion(); + mlir::Block *entryBlock = genEntryBlock(converter, args, region); + bindEntryBlockArgs(converter, targetOp, args); // Check if cloning the bounds introduced any dependency on the outer region. // If so, then either clone them as well if they are MemoryEffectFree, or else @@ -916,11 +979,11 @@ static void genBodyOfTargetOp( mlir::Operation *valOp = val.getDefiningOp(); if (mlir::isMemoryEffectFree(valOp)) { mlir::Operation *clonedOp = valOp->clone(); - regionBlock->push_front(clonedOp); - val.replaceUsesWithIf( - clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) { - return use.getOwner()->getBlock() == regionBlock; - }); + entryBlock->push_front(clonedOp); + val.replaceUsesWithIf(clonedOp->getResult(0), + [entryBlock](mlir::OpOperand &use) { + return use.getOwner()->getBlock() == entryBlock; + }); } else { auto savedIP = firOpBuilder.getInsertionPoint(); firOpBuilder.setInsertionPointAfter(valOp); @@ -941,18 +1004,23 @@ static void genBodyOfTargetOp( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT), mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType()); + // Get the index of the first non-map argument before modifying mapVars, + // then append an element to mapVars and an associated entry block + // argument at that index. 
+ unsigned insertIndex = + argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs(); targetOp.getMapVarsMutable().append(mapOp); + mlir::Value clonedValArg = region.insertArgument( + insertIndex, copyVal.getType(), copyVal.getLoc()); - mlir::Value clonedValArg = - region.addArgument(copyVal.getType(), copyVal.getLoc()); - firOpBuilder.setInsertionPointToStart(regionBlock); + firOpBuilder.setInsertionPointToStart(entryBlock); auto loadOp = firOpBuilder.create(clonedValArg.getLoc(), clonedValArg); - val.replaceUsesWithIf( - loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) { - return use.getOwner()->getBlock() == regionBlock; - }); - firOpBuilder.setInsertionPoint(regionBlock, savedIP); + val.replaceUsesWithIf(loadOp->getResult(0), + [entryBlock](mlir::OpOperand &use) { + return use.getOwner()->getBlock() == entryBlock; + }); + firOpBuilder.setInsertionPoint(entryBlock, savedIP); } } valuesDefinedAbove.clear(); @@ -980,14 +1048,14 @@ static void genBodyOfTargetOp( firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); // If we map a common block using it's symbol e.g. map(tofrom: /common_block/) - // and accessing it's members within the target region, there is a large + // and accessing its members within the target region, there is a large // chance we will end up with uses external to the region accessing the common // resolve these, we do so by generating new common block member accesses // within the region, binding them to the member symbol for the scope of the // region so that subsequent code generation within the region will utilise // our new member accesses we have created. - genIntermediateCommonBlockAccessors(converter, currentLocation, region, - mapSyms); + genIntermediateCommonBlockAccessors( + converter, currentLocation, argIface.getMapBlockArgs(), args.map.syms); if (ConstructQueue::const_iterator next = std::next(item); next != queue.end()) { @@ -1013,7 +1081,7 @@ static OpTy genOpWithBody(const OpWithBodyGenInfo &info, template static OpTy genWrapperOp(lower::AbstractConverter &converter, mlir::Location loc, const ClauseOpsTy &clauseOps, - llvm::ArrayRef blockArgTypes) { + const EntryBlockArgs &args) { static_assert( OpTy::template hasTrait(), "expected a loop wrapper"); @@ -1023,9 +1091,7 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter, auto op = firOpBuilder.create(loc, clauseOps); // Create entry block with arguments. 
- llvm::SmallVector locs(blockArgTypes.size(), loc); - firOpBuilder.createBlock(&op.getRegion(), /*insertPt=*/{}, blockArgTypes, - locs); + genEntryBlock(converter, args, op.getRegion()); firOpBuilder.setInsertionPoint( lower::genOpenMPTerminator(firOpBuilder, op, loc)); @@ -1105,39 +1171,38 @@ static void genParallelClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List &clauses, mlir::Location loc, mlir::omp::ParallelOperands &clauseOps, - llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); cp.processNumThreads(stmtCtx, clauseOps); cp.processProcBind(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); } static void genSectionsClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, const List &clauses, mlir::Location loc, mlir::omp::SectionsOperands &clauseOps, - llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processNowait(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); // TODO Support delayed privatization. } -static void genSimdClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - const List &clauses, mlir::Location loc, - mlir::omp::SimdOperands &clauseOps) { +static void genSimdClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location loc, + mlir::omp::SimdOperands &clauseOps, + llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAligned(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); cp.processNontemporal(clauseOps); cp.processOrder(clauseOps); - cp.processReduction(loc, clauseOps); + cp.processReduction(loc, clauseOps, reductionSyms); cp.processSafelen(clauseOps); cp.processSimdlen(clauseOps); @@ -1160,24 +1225,16 @@ static void genTargetClauses( lower::StatementContext &stmtCtx, const List &clauses, mlir::Location loc, bool processHostOnlyClauses, mlir::omp::TargetOperands &clauseOps, - llvm::SmallVectorImpl &mapSyms, - llvm::SmallVectorImpl &mapLocs, - llvm::SmallVectorImpl &mapTypes, - llvm::SmallVectorImpl &deviceAddrSyms, - llvm::SmallVectorImpl &deviceAddrLocs, - llvm::SmallVectorImpl &deviceAddrTypes, - llvm::SmallVectorImpl &devicePtrSyms, - llvm::SmallVectorImpl &devicePtrLocs, - llvm::SmallVectorImpl &devicePtrTypes) { + llvm::SmallVectorImpl &hasDeviceAddrSyms, + llvm::SmallVectorImpl &isDevicePtrSyms, + llvm::SmallVectorImpl &mapSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDepend(clauseOps); cp.processDevice(stmtCtx, clauseOps); - cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs, - deviceAddrSyms); + cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms); cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); - cp.processIsDevicePtr(clauseOps, devicePtrTypes, devicePtrLocs, - devicePtrSyms); - cp.processMap(loc, stmtCtx, clauseOps, &mapSyms, &mapLocs, &mapTypes); + cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); + cp.processMap(loc, stmtCtx, clauseOps, &mapSyms); if (processHostOnlyClauses) 
cp.processNowait(clauseOps); @@ -1197,32 +1254,26 @@ static void genTargetDataClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List &clauses, mlir::Location loc, mlir::omp::TargetDataOperands &clauseOps, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSyms) { + llvm::SmallVectorImpl &useDeviceAddrSyms, + llvm::SmallVectorImpl &useDevicePtrSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDevice(stmtCtx, clauseOps); cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); cp.processMap(loc, stmtCtx, clauseOps); - cp.processUseDeviceAddr(stmtCtx, clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); - cp.processUseDevicePtr(stmtCtx, clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); + cp.processUseDeviceAddr(stmtCtx, clauseOps, useDeviceAddrSyms); + cp.processUseDevicePtr(stmtCtx, clauseOps, useDevicePtrSyms); // This function implements the deprecated functionality of use_device_ptr // that allows users to provide non-CPTR arguments to it with the caveat // that the compiler will treat them as use_device_addr. A lot of legacy // code may still depend on this functionality, so we should support it // in some manner. We do so currently by simply shifting non-cptr operands - // from the use_device_ptr list into the front of the use_device_addr list - // whilst maintaining the ordering of useDeviceLocs, useDeviceSyms and - // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg - // ordering. + // from the use_device_ptr lists into the use_device_addr lists. // TODO: Perhaps create a user provideable compiler option that will // re-introduce a hard-error rather than a warning in these cases. 
promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - clauseOps.useDeviceAddrVars, clauseOps.useDevicePtrVars, useDeviceTypes, - useDeviceLocs, useDeviceSyms); + clauseOps.useDeviceAddrVars, useDeviceAddrSyms, + clauseOps.useDevicePtrVars, useDevicePtrSyms); } static void genTargetEnterExitUpdateDataClauses( @@ -1300,13 +1351,12 @@ static void genWsloopClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List &clauses, mlir::Location loc, mlir::omp::WsloopOperands &clauseOps, - llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processNowait(clauseOps); cp.processOrder(clauseOps); cp.processOrdered(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); cp.processSchedule(stmtCtx, clauseOps); cp.processTODO( @@ -1369,21 +1419,18 @@ genFlushOp(lower::AbstractConverter &converter, lower::SymMap &symTable, converter.getCurrentLocation(), operandRange); } -static mlir::omp::LoopNestOp -genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, mlir::Location loc, - const ConstructQueue &queue, ConstructQueue::const_iterator item, - mlir::omp::LoopNestOperands &clauseOps, - llvm::ArrayRef iv, - llvm::ArrayRef wrapperSyms, - llvm::ArrayRef wrapperArgs, - llvm::omp::Directive directive, DataSharingProcessor &dsp) { - assert(wrapperSyms.size() == wrapperArgs.size() && - "Number of symbols and wrapper block arguments must match"); - +static mlir::omp::LoopNestOp genLoopNestOp( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item, mlir::omp::LoopNestOperands &clauseOps, + llvm::ArrayRef iv, + llvm::ArrayRef< + std::pair> + wrapperArgs, + llvm::omp::Directive directive, DataSharingProcessor &dsp) { auto ivCallback = [&](mlir::Operation *op) { - genLoopVars(op, converter, loc, iv, wrapperSyms, wrapperArgs); + genLoopVars(op, converter, loc, iv, wrapperArgs); return llvm::SmallVector(iv); }; @@ -1455,83 +1502,26 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item, mlir::omp::ParallelOperands &clauseOps, - llvm::ArrayRef reductionSyms, - llvm::ArrayRef reductionTypes, - DataSharingProcessor *dsp, bool isComposite = false) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - auto reductionCallback = [&](mlir::Operation *op) { - genReductionVars(op, converter, loc, reductionSyms, reductionTypes); - return llvm::SmallVector(reductionSyms); + const EntryBlockArgs &args, DataSharingProcessor *dsp, + bool isComposite = false) { + auto genRegionEntryCB = [&](mlir::Operation *op) { + genEntryBlock(converter, args, op->getRegion(0)); + bindEntryBlockArgs( + converter, llvm::cast(op), args); + return llvm::to_vector(llvm::concat( + args.priv.syms, args.reduction.syms)); }; + assert((!enableDelayedPrivatization || dsp) && + "expected valid DataSharingProcessor"); OpWithBodyGenInfo genInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_parallel) .setClauses(&item->clauses) - .setGenRegionEntryCb(reductionCallback) - 
.setGenSkeletonOnly(isComposite); - - if (!enableDelayedPrivatization) { - auto parallelOp = - genOpWithBody(genInfo, queue, item, clauseOps); - parallelOp.setComposite(isComposite); - return parallelOp; - } + .setGenRegionEntryCb(genRegionEntryCB) + .setGenSkeletonOnly(isComposite) + .setDataSharingProcessor(dsp); - assert(dsp && "expected valid DataSharingProcessor"); - auto genRegionEntryCB = [&](mlir::Operation *op) { - auto parallelOp = llvm::cast(op); - - llvm::SmallVector reductionLocs( - clauseOps.reductionVars.size(), loc); - - llvm::SmallVector allRegionArgTypes; - mergePrivateVarsInfo(parallelOp, reductionTypes, - llvm::function_ref{ - [](mlir::Value v) { return v.getType(); }}, - allRegionArgTypes, /*addBeforePrivate=*/false); - - llvm::SmallVector allRegionArgLocs; - mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs), - llvm::function_ref{ - [](mlir::Value v) { return v.getLoc(); }}, - allRegionArgLocs, /*addBeforePrivate=*/false); - - mlir::Region ®ion = parallelOp.getRegion(); - firOpBuilder.createBlock(®ion, /*insertPt=*/{}, allRegionArgTypes, - allRegionArgLocs); - - llvm::SmallVector allSymbols( - dsp->getDelayedPrivSymbols()); - allSymbols.append(reductionSyms.begin(), reductionSyms.end()); - - unsigned argIdx = 0; - for (const semantics::Symbol *arg : allSymbols) { - auto bind = [&](const semantics::Symbol *sym) { - mlir::BlockArgument blockArg = region.getArgument(argIdx); - ++argIdx; - converter.bindSymbol(*sym, - hlfir::translateToExtendedValue( - loc, firOpBuilder, hlfir::Entity{blockArg}, - /*contiguousHint=*/ - evaluate::IsSimplyContiguous( - *sym, converter.getFoldingContext())) - .first); - }; - - if (const auto *commonDet = - arg->detailsIf()) { - for (const auto &mem : commonDet->objects()) - bind(&*mem); - } else - bind(arg); - } - - return allSymbols; - }; - - genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(dsp); auto parallelOp = genOpWithBody(genInfo, queue, item, clauseOps); parallelOp.setComposite(isComposite); @@ -1547,11 +1537,10 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item, const parser::OmpSectionBlocks §ionBlocks) { - llvm::SmallVector reductionTypes; - llvm::SmallVector reductionSyms; mlir::omp::SectionsOperands clauseOps; + llvm::SmallVector reductionSyms; genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps, - reductionTypes, reductionSyms); + reductionSyms); auto &builder = converter.getFirOpBuilder(); @@ -1584,15 +1573,20 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // SECTIONS construct. auto sectionsOp = builder.create(loc, clauseOps); - // create entry block with reduction variables as arguments - llvm::SmallVector blockArgLocs(reductionSyms.size(), loc); - builder.createBlock(§ionsOp->getRegion(0), {}, reductionTypes, - blockArgLocs); + // Create entry block with reduction variables as arguments. + EntryBlockArgs args; + // TODO: Add private syms and vars. 
+ args.reduction.syms = reductionSyms; + args.reduction.vars = clauseOps.reductionVars; + + genEntryBlock(converter, args, sectionsOp.getRegion()); mlir::Operation *terminator = lower::genOpenMPTerminator(builder, sectionsOp, loc); auto reductionCallback = [&](mlir::Operation *op) { - genReductionVars(op, converter, loc, reductionSyms, reductionTypes); + genEntryBlock(converter, args, op->getRegion(0)); + bindEntryBlockArgs( + converter, llvm::cast(op), args); return reductionSyms; }; @@ -1686,14 +1680,11 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, .getIsTargetDevice(); mlir::omp::TargetOperands clauseOps; - llvm::SmallVector mapSyms, devicePtrSyms, - deviceAddrSyms; - llvm::SmallVector mapLocs, devicePtrLocs, deviceAddrLocs; - llvm::SmallVector mapTypes, devicePtrTypes, deviceAddrTypes; + llvm::SmallVector mapSyms, isDevicePtrSyms, + hasDeviceAddrSyms; genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - processHostOnlyClauses, clauseOps, mapSyms, mapLocs, - mapTypes, deviceAddrSyms, deviceAddrLocs, deviceAddrTypes, - devicePtrSyms, devicePtrLocs, devicePtrTypes); + processHostOnlyClauses, clauseOps, hasDeviceAddrSyms, + isDevicePtrSyms, mapSyms); DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ @@ -1800,15 +1791,24 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, clauseOps.mapVars.push_back(mapOp); mapSyms.push_back(&sym); - mapLocs.push_back(baseOp.getLoc()); - mapTypes.push_back(baseOp.getType()); } }; lower::pft::visitAllSymbols(eval, captureImplicitMap); auto targetOp = firOpBuilder.create(loc, clauseOps); - genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, mapSyms, - mapLocs, mapTypes, loc, queue, item, dsp); + + llvm::SmallVector mapBaseValues; + extractMappedBaseValues(clauseOps.mapVars, mapBaseValues); + + EntryBlockArgs args; + // TODO: Add in_reduction syms and vars. 
+ args.map.syms = mapSyms; + args.map.vars = mapBaseValues; + args.priv.syms = dsp.getDelayedPrivSymbols(); + args.priv.vars = clauseOps.privateVars; + + genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc, + queue, item, dsp); return targetOp; } @@ -1820,18 +1820,28 @@ genTargetDataOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { lower::StatementContext stmtCtx; mlir::omp::TargetDataOperands clauseOps; - llvm::SmallVector useDeviceTypes; - llvm::SmallVector useDeviceLocs; - llvm::SmallVector useDeviceSyms; + llvm::SmallVector useDeviceAddrSyms, + useDevicePtrSyms; genTargetDataClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - clauseOps, useDeviceTypes, useDeviceLocs, useDeviceSyms); + clauseOps, useDeviceAddrSyms, useDevicePtrSyms); auto targetDataOp = converter.getFirOpBuilder().create(loc, clauseOps); - genBodyOfTargetDataOp(converter, symTable, semaCtx, eval, targetDataOp, - useDeviceSyms, useDeviceLocs, useDeviceTypes, loc, - queue, item); + + llvm::SmallVector useDeviceAddrBaseValues, + useDevicePtrBaseValues; + extractMappedBaseValues(clauseOps.useDeviceAddrVars, useDeviceAddrBaseValues); + extractMappedBaseValues(clauseOps.useDevicePtrVars, useDevicePtrBaseValues); + + EntryBlockArgs args; + args.useDeviceAddr.syms = useDeviceAddrSyms; + args.useDeviceAddr.vars = useDeviceAddrBaseValues; + args.useDevicePtr.syms = useDevicePtrSyms; + args.useDevicePtr.vars = useDevicePtrBaseValues; + + genBodyOfTargetDataOp(converter, symTable, semaCtx, eval, targetDataOp, args, + loc, queue, item); return targetDataOp; } @@ -1953,22 +1963,20 @@ static void genStandaloneDistribute(lower::AbstractConverter &converter, /*shouldCollectPreDeterminedSymbols=*/true, enableDelayedPrivatizationStaging, &symTable); dsp.processStep1(&distributeClauseOps); - llvm::SmallVector privateVarTypes{}; - - for (mlir::Value privateVar : distributeClauseOps.privateVars) - privateVarTypes.push_back(privateVar.getType()); mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, loopNestClauseOps, iv); + EntryBlockArgs distributeArgs; + distributeArgs.priv.syms = dsp.getDelayedPrivSymbols(); + distributeArgs.priv.vars = distributeClauseOps.privateVars; auto distributeOp = genWrapperOp( - converter, loc, distributeClauseOps, privateVarTypes); + converter, loc, distributeClauseOps, distributeArgs); genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item, - loopNestClauseOps, iv, dsp.getDelayedPrivSymbols(), - distributeOp.getRegion().getArguments(), + loopNestClauseOps, iv, {{distributeOp, distributeArgs}}, llvm::omp::Directive::OMPD_distribute, dsp); } @@ -1981,10 +1989,9 @@ static void genStandaloneDo(lower::AbstractConverter &converter, lower::StatementContext stmtCtx; mlir::omp::WsloopOperands wsloopClauseOps; - llvm::SmallVector reductionSyms; - llvm::SmallVector reductionTypes; + llvm::SmallVector wsloopReductionSyms; genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - wsloopClauseOps, reductionTypes, reductionSyms); + wsloopClauseOps, wsloopReductionSyms); // TODO: Support delayed privatization. DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, @@ -1997,13 +2004,15 @@ static void genStandaloneDo(lower::AbstractConverter &converter, genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, loopNestClauseOps, iv); - // TODO: Add private variables to entry block arguments. 
+ EntryBlockArgs wsloopArgs; + // TODO: Add private syms and vars. + wsloopArgs.reduction.syms = wsloopReductionSyms; + wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars; auto wsloopOp = genWrapperOp( - converter, loc, wsloopClauseOps, reductionTypes); + converter, loc, wsloopClauseOps, wsloopArgs); genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item, - loopNestClauseOps, iv, reductionSyms, - wsloopOp.getRegion().getArguments(), + loopNestClauseOps, iv, {{wsloopOp, wsloopArgs}}, llvm::omp::Directive::OMPD_do, dsp); } @@ -2016,21 +2025,27 @@ static void genStandaloneParallel(lower::AbstractConverter &converter, ConstructQueue::const_iterator item) { lower::StatementContext stmtCtx; - mlir::omp::ParallelOperands clauseOps; - llvm::SmallVector reductionSyms; - llvm::SmallVector reductionTypes; - genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps, - reductionTypes, reductionSyms); + mlir::omp::ParallelOperands parallelClauseOps; + llvm::SmallVector parallelReductionSyms; + genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, + parallelClauseOps, parallelReductionSyms); std::optional dsp; if (enableDelayedPrivatization) { dsp.emplace(converter, semaCtx, item->clauses, eval, lower::omp::isLastItemInQueue(item, queue), /*useDelayedPrivatization=*/true, &symTable); - dsp->processStep1(&clauseOps); + dsp->processStep1(¶llelClauseOps); } - genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, clauseOps, - reductionSyms, reductionTypes, + + EntryBlockArgs parallelArgs; + if (dsp) + parallelArgs.priv.syms = dsp->getDelayedPrivSymbols(); + parallelArgs.priv.vars = parallelClauseOps.privateVars; + parallelArgs.reduction.syms = parallelReductionSyms; + parallelArgs.reduction.vars = parallelClauseOps.reductionVars; + genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, + parallelClauseOps, parallelArgs, enableDelayedPrivatization ? &dsp.value() : nullptr); } @@ -2041,7 +2056,9 @@ static void genStandaloneSimd(lower::AbstractConverter &converter, const ConstructQueue &queue, ConstructQueue::const_iterator item) { mlir::omp::SimdOperands simdClauseOps; - genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps); + llvm::SmallVector simdReductionSyms; + genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps, + simdReductionSyms); // TODO: Support delayed privatization. DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, @@ -2054,13 +2071,15 @@ static void genStandaloneSimd(lower::AbstractConverter &converter, genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, loopNestClauseOps, iv); - // TODO: Populate entry block arguments with reduction and private variables. - auto simdOp = genWrapperOp(converter, loc, simdClauseOps, - /*blockArgTypes=*/{}); + EntryBlockArgs simdArgs; + // TODO: Add private syms and vars. + simdArgs.reduction.syms = simdReductionSyms; + simdArgs.reduction.vars = simdClauseOps.reductionVars; + auto simdOp = + genWrapperOp(converter, loc, simdClauseOps, simdArgs); genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item, - loopNestClauseOps, iv, - /*wrapperSyms=*/{}, simdOp.getRegion().getArguments(), + loopNestClauseOps, iv, {{simdOp, simdArgs}}, llvm::omp::Directive::OMPD_simd, dsp); } @@ -2093,19 +2112,21 @@ static void genCompositeDistributeParallelDo( // Create parent omp.parallel first. 
   mlir::omp::ParallelOperands parallelClauseOps;
   llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
-  llvm::SmallVector<mlir::Type> parallelReductionTypes;
   genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
-                     parallelClauseOps, parallelReductionTypes,
-                     parallelReductionSyms);
+                     parallelClauseOps, parallelReductionSyms);
 
   DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval,
                            /*shouldCollectPreDeterminedSymbols=*/true,
                            /*useDelayedPrivatization=*/true, &symTable);
   dsp.processStep1(&parallelClauseOps);
 
+  EntryBlockArgs parallelArgs;
+  parallelArgs.priv.syms = dsp.getDelayedPrivSymbols();
+  parallelArgs.priv.vars = parallelClauseOps.privateVars;
+  parallelArgs.reduction.syms = parallelReductionSyms;
+  parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
-                parallelClauseOps, parallelReductionSyms,
-                parallelReductionTypes, &dsp, /*isComposite=*/true);
+                parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true);
 
   // Clause processing.
   mlir::omp::DistributeOperands distributeClauseOps;
@@ -2114,9 +2135,8 @@ static void genCompositeDistributeParallelDo(
 
   mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
-  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
   genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
-                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+                   wsloopClauseOps, wsloopReductionSyms);
 
   mlir::omp::LoopNestOperands loopNestClauseOps;
   llvm::SmallVector<mlir::Value> iv;
@@ -2124,27 +2144,23 @@ static void genCompositeDistributeParallelDo(
                      loopNestClauseOps, iv);
 
   // Operation creation.
-  // TODO: Populate entry block arguments with private variables.
+  EntryBlockArgs distributeArgs;
+  // TODO: Add private syms and vars.
   auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
-      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+      converter, loc, distributeClauseOps, distributeArgs);
   distributeOp.setComposite(/*val=*/true);
 
-  // TODO: Add private variables to entry block arguments.
+  EntryBlockArgs wsloopArgs;
+  // TODO: Add private syms and vars.
+  wsloopArgs.reduction.syms = wsloopReductionSyms;
+  wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
   auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
-      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+      converter, loc, wsloopClauseOps, wsloopArgs);
   wsloopOp.setComposite(/*val=*/true);
 
-  // Construct wrapper entry block list and associated symbols. It is important
-  // that the symbol order and the block argument order match, so that the
-  // symbol-value bindings created are correct.
-  auto &wrapperSyms = wsloopReductionSyms;
-
-  auto wrapperArgs = llvm::to_vector(
-      llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
-                                        wsloopOp.getRegion().getArguments()));
-
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, doItem,
-                loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
+                loopNestClauseOps, iv,
+                {{distributeOp, distributeArgs}, {wsloopOp, wsloopArgs}},
                 llvm::omp::Directive::OMPD_distribute_parallel_do, dsp);
 }
 
@@ -2164,19 +2180,21 @@ static void genCompositeDistributeParallelDoSimd(
   // Create parent omp.parallel first.
   mlir::omp::ParallelOperands parallelClauseOps;
   llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
-  llvm::SmallVector<mlir::Type> parallelReductionTypes;
   genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
-                     parallelClauseOps, parallelReductionTypes,
-                     parallelReductionSyms);
+                     parallelClauseOps, parallelReductionSyms);
 
   DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval,
                            /*shouldCollectPreDeterminedSymbols=*/true,
                            /*useDelayedPrivatization=*/true, &symTable);
   dsp.processStep1(&parallelClauseOps);
 
+  EntryBlockArgs parallelArgs;
+  parallelArgs.priv.syms = dsp.getDelayedPrivSymbols();
+  parallelArgs.priv.vars = parallelClauseOps.privateVars;
+  parallelArgs.reduction.syms = parallelReductionSyms;
+  parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
-                parallelClauseOps, parallelReductionSyms,
-                parallelReductionTypes, &dsp, /*isComposite=*/true);
+                parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true);
 
   // Clause processing.
   mlir::omp::DistributeOperands distributeClauseOps;
@@ -2185,12 +2203,13 @@ static void genCompositeDistributeParallelDoSimd(
 
   mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
-  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
   genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
-                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+                   wsloopClauseOps, wsloopReductionSyms);
 
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
+  llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
+                 simdReductionSyms);
 
   mlir::omp::LoopNestOperands loopNestClauseOps;
   llvm::SmallVector<mlir::Value> iv;
@@ -2198,32 +2217,33 @@ static void genCompositeDistributeParallelDoSimd(
                      loopNestClauseOps, iv);
 
   // Operation creation.
-  // TODO: Populate entry block arguments with private variables.
+  EntryBlockArgs distributeArgs;
+  // TODO: Add private syms and vars.
   auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
-      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+      converter, loc, distributeClauseOps, distributeArgs);
   distributeOp.setComposite(/*val=*/true);
 
-  // TODO: Add private variables to entry block arguments.
+  EntryBlockArgs wsloopArgs;
+  // TODO: Add private syms and vars.
+  wsloopArgs.reduction.syms = wsloopReductionSyms;
+  wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
   auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
-      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+      converter, loc, wsloopClauseOps, wsloopArgs);
   wsloopOp.setComposite(/*val=*/true);
 
-  // TODO: Populate entry block arguments with reduction and private variables.
-  auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
-                                                /*blockArgTypes=*/{});
+  EntryBlockArgs simdArgs;
+  // TODO: Add private syms and vars.
+  simdArgs.reduction.syms = simdReductionSyms;
+  simdArgs.reduction.vars = simdClauseOps.reductionVars;
+  auto simdOp =
+      genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
   simdOp.setComposite(/*val=*/true);
 
-  // Construct wrapper entry block list and associated symbols. It is important
-  // that the symbol order and the block argument order match, so that the
-  // symbol-value bindings created are correct.
-  auto &wrapperSyms = wsloopReductionSyms;
-
-  auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>(
-      distributeOp.getRegion().getArguments(),
-      wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments()));
-
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
-                loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
+                loopNestClauseOps, iv,
+                {{distributeOp, distributeArgs},
+                 {wsloopOp, wsloopArgs},
+                 {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_distribute_parallel_do_simd, dsp);
 }
 
@@ -2246,7 +2266,9 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
                          loc, distributeClauseOps);
 
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
+  llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
+                 simdReductionSyms);
 
   // TODO: Support delayed privatization.
   DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval,
@@ -2262,26 +2284,23 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
                      loopNestClauseOps, iv);
 
   // Operation creation.
-  // TODO: Populate entry block arguments with private variables.
+  EntryBlockArgs distributeArgs;
+  // TODO: Add private syms and vars.
   auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
-      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+      converter, loc, distributeClauseOps, distributeArgs);
   distributeOp.setComposite(/*val=*/true);
 
-  // TODO: Populate entry block arguments with reduction and private variables.
-  auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
-                                                /*blockArgTypes=*/{});
+  EntryBlockArgs simdArgs;
+  // TODO: Add private syms and vars.
+  simdArgs.reduction.syms = simdReductionSyms;
+  simdArgs.reduction.vars = simdClauseOps.reductionVars;
+  auto simdOp =
+      genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
   simdOp.setComposite(/*val=*/true);
 
-  // Construct wrapper entry block list and associated symbols. It is important
-  // that the symbol order and the block argument order match, so that the
-  // symbol-value bindings created are correct.
-  // TODO: Add omp.distribute private and omp.simd private and reduction args.
-  auto wrapperArgs = llvm::to_vector(
-      llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
-                                        simdOp.getRegion().getArguments()));
-
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
-                loopNestClauseOps, iv, /*wrapperSyms=*/{}, wrapperArgs,
+                loopNestClauseOps, iv,
+                {{distributeOp, distributeArgs}, {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_distribute_simd, dsp);
 }
 
@@ -2300,12 +2319,13 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
   // Clause processing.
   mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
-  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
   genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
-                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+                   wsloopClauseOps, wsloopReductionSyms);
 
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
+  llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
+                 simdReductionSyms);
 
   // TODO: Support delayed privatization.
   DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval,
@@ -2321,25 +2341,25 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
                      loopNestClauseOps, iv);
 
   // Operation creation.
-  // TODO: Add private variables to entry block arguments.
+  EntryBlockArgs wsloopArgs;
+  // TODO: Add private syms and vars.
+  wsloopArgs.reduction.syms = wsloopReductionSyms;
+  wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
   auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
-      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+      converter, loc, wsloopClauseOps, wsloopArgs);
   wsloopOp.setComposite(/*val=*/true);
 
-  // TODO: Populate entry block arguments with reduction and private variables.
-  auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
-                                                /*blockArgTypes=*/{});
+  EntryBlockArgs simdArgs;
+  // TODO: Add private syms and vars.
+  simdArgs.reduction.syms = simdReductionSyms;
+  simdArgs.reduction.vars = simdClauseOps.reductionVars;
+  auto simdOp =
+      genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
   simdOp.setComposite(/*val=*/true);
 
-  // Construct wrapper entry block list and associated symbols. It is important
-  // that the symbol and block argument order match, so that the symbol-value
-  // bindings created are correct.
-  // TODO: Add omp.wsloop private and omp.simd private and reduction args.
-  auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>(
-      wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments()));
-
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
-                loopNestClauseOps, iv, wsloopReductionSyms, wrapperArgs,
+                loopNestClauseOps, iv,
+                {{wsloopOp, wsloopArgs}, {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_do_simd, dsp);
 }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 9da15ba303a47..6b98ea3d0615b 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -722,7 +722,7 @@ void ReductionProcessor::addDeclareReduction(
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
     llvm::SmallVectorImpl<bool> &reduceVarByRef,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
-    llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols) {
+    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   if (std::get>(
@@ -753,8 +753,7 @@ void ReductionProcessor::addDeclareReduction(
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   for (const Object &object : objectList) {
     const semantics::Symbol *symbol = object.sym();
-    if (reductionSymbols)
-      reductionSymbols->push_back(symbol);
+    reductionSymbols.push_back(symbol);
     mlir::Value symVal = converter.getSymbolAddress(*symbol);
     mlir::Type eleType;
     auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 0ed5782e5da1b..5f4d742b62cb1 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -126,8 +126,7 @@ class ReductionProcessor {
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
       llvm::SmallVectorImpl<bool> &reduceVarByRef,
      llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
-      llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols =
-          nullptr);
+      llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
 };
 
 template
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 47bc12e1b8a03..a09d91540ec22 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -264,9 +264,7 @@ void insertChildMapInfoIntoParent(
     std::map> &parentMemberIndices,
     llvm::SmallVectorImpl<mlir::Value> &mapOperands,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
-    llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
-    llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) {
+    llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
   for (auto indices : parentMemberIndices) {
     bool parentExists = false;
     size_t parentIdx;
@@ -322,11 +320,6 @@ void insertChildMapInfoIntoParent(
 
       mapOperands.push_back(mapOp);
       mapSyms.push_back(indices.first);
-
-      if (mapSymTypes)
-        mapSymTypes->push_back(mapOp.getType());
-      if (mapSymLocs)
-        mapSymLocs->push_back(mapOp.getLoc());
     }
   }
 }
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 658d062e67b27..4a569cd64355d 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -78,9 +78,7 @@ void insertChildMapInfoIntoParent(
     std::map> &parentMemberIndices,
     llvm::SmallVectorImpl<mlir::Value> &mapOperands,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
-    llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
-    llvm::SmallVectorImpl<mlir::Location> *mapSymLocs);
+    llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms);
 
 mlir::Type getLoopVarType(lower::AbstractConverter &converter,
                           std::size_t loopVarTypeSize);
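
Note for readers of this excerpt: the EntryBlockArgs aggregate populated throughout the OpenMP.cpp hunks above is introduced elsewhere in this patch and its definition is not visible here. The sketch below is inferred purely from the uses shown (args.map, args.priv, args.reduction, args.useDeviceAddr, args.useDevicePtr, each with .syms and .vars); the type names, the isValid() helper, and the exact field list are assumptions, not the patch's actual definition.

// Illustrative sketch only, inferred from usage in the hunks above; not the
// definition added by this patch.
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Value.h"

namespace Fortran::semantics {
class Symbol; // real definition lives in flang's Semantics library
}

struct EntryBlockArgsEntry {
  // Fortran symbols attached to one clause kind (e.g. the reduction symbols).
  llvm::ArrayRef<const Fortran::semantics::Symbol *> syms;
  // The corresponding values; one per symbol, in the same order.
  llvm::ArrayRef<mlir::Value> vars;

  // Symbols and values must stay in lockstep so that the symbol-value
  // bindings created for the region's entry block arguments are correct.
  bool isValid() const { return syms.size() == vars.size(); }
};

struct EntryBlockArgs {
  EntryBlockArgsEntry map, priv, reduction, useDeviceAddr, useDevicePtr;
};

Because each wrapper operation is now handed to genLoopNestOp together with its own EntryBlockArgs (for example {{distributeOp, distributeArgs}, {wsloopOp, wsloopArgs}}), the hand-written llvm::concat of the wrappers' region arguments and the shared wrapperSyms list that the old code maintained are no longer needed.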
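The same pairing also explains why the *Types and *Locs out-parameters (useDeviceTypes/useDeviceLocs, mapSymTypes/mapSymLocs, reductionTypes) could be deleted outright: an mlir::Value already carries both its type and its location, so whichever code creates the entry block arguments can derive them on demand from the collected vars. A minimal sketch of that derivation follows; the helper name is hypothetical and not taken from the patch.

// Hypothetical helper, for illustration only: types and locations of entry
// block arguments can be recovered from the values themselves.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"

static void collectBlockArgTypesAndLocs(llvm::ArrayRef<mlir::Value> vars,
                                        llvm::SmallVectorImpl<mlir::Type> &types,
                                        llvm::SmallVectorImpl<mlir::Location> &locs) {
  for (mlir::Value v : vars) {
    types.push_back(v.getType()); // type of the block argument to create
    locs.push_back(v.getLoc());   // diagnostic location for that argument
  }
}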