@@ -366,10 +366,29 @@ getDeclareTargetFunctionDevice(
366366 return std::nullopt ;
367367}
368368
369- static llvm::SmallVector<const Fortran::semantics::Symbol *>
369+ /// Set up the entry block of the given `omp.loop_nest` operation, adding a
370+ /// block argument for each loop induction variable and allocating and
371+ /// initializing a private value to hold each of them.
372+ ///
373+ /// This function can also bind the symbols of any variables that should match
374+ /// block arguments on parent loop wrapper operations attached to the same
375+ /// loop. This allows the introduction of any necessary `hlfir.declare`
376+ /// operations inside of the entry block of the `omp.loop_nest` operation and
377+ /// not directly under any of the wrappers, which would invalidate them.
378+ ///
379+ /// \param [in] op - the loop nest operation.
380+ /// \param [in] converter - PFT to MLIR conversion interface.
381+ /// \param [in] loc - location.
382+ /// \param [in] args - symbols of induction variables.
383+ /// \param [in] wrapperSyms - symbols of variables to be mapped to loop wrapper
384+ /// entry block arguments.
385+ /// \param [in] wrapperArgs - entry block arguments of parent loop wrappers.
386+ static void
370387genLoopVars (mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
371388 mlir::Location &loc,
372- llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
389+ llvm::ArrayRef<const Fortran::semantics::Symbol *> args,
390+ llvm::ArrayRef<const Fortran::semantics::Symbol *> wrapperSyms = {},
391+ llvm::ArrayRef<mlir::BlockArgument> wrapperArgs = {}) {
373392 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
374393 auto &region = op->getRegion (0 );
375394
@@ -380,6 +399,12 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
380399 llvm::SmallVector<mlir::Type> tiv (args.size (), loopVarType);
381400 llvm::SmallVector<mlir::Location> locs (args.size (), loc);
382401 firOpBuilder.createBlock (&region, {}, tiv, locs);
402+
403+ // Bind the entry block arguments of parent wrappers to the corresponding
404+ // symbols.
405+ for (auto [arg, prv] : llvm::zip_equal (wrapperSyms, wrapperArgs))
406+ converter.bindSymbol (*arg, prv);
407+
383408 // The argument is not currently in memory, so make a temporary for the
384409 // argument, and store it there, then bind that location to the argument.
385410 mlir::Operation *storeOp = nullptr ;
@@ -389,7 +414,6 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
389414 createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
390415 }
391416 firOpBuilder.setInsertionPointAfter (storeOp);
392- return llvm::SmallVector<const Fortran::semantics::Symbol *>(args);
393417}
394418
395419static void genReductionVars (
@@ -410,58 +434,6 @@ static void genReductionVars(
410434 }
411435}
412436
413- static llvm::SmallVector<const Fortran::semantics::Symbol *>
414- genLoopAndReductionVars (
415- mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
416- mlir::Location &loc,
417- llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
418- llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
419- llvm::ArrayRef<mlir::Type> reductionTypes) {
420- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
421-
422- llvm::SmallVector<mlir::Type> blockArgTypes;
423- llvm::SmallVector<mlir::Location> blockArgLocs;
424- blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
425- blockArgLocs.reserve (blockArgTypes.size ());
426- mlir::Block *entryBlock;
427-
428- if (loopArgs.size ()) {
429- std::size_t loopVarTypeSize = 0 ;
430- for (const Fortran::semantics::Symbol *arg : loopArgs)
431- loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
432- mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
433- std::fill_n (std::back_inserter (blockArgTypes), loopArgs.size (),
434- loopVarType);
435- std::fill_n (std::back_inserter (blockArgLocs), loopArgs.size (), loc);
436- }
437- if (reductionArgs.size ()) {
438- llvm::copy (reductionTypes, std::back_inserter (blockArgTypes));
439- std::fill_n (std::back_inserter (blockArgLocs), reductionArgs.size (), loc);
440- }
441- entryBlock = firOpBuilder.createBlock (&op->getRegion (0 ), {}, blockArgTypes,
442- blockArgLocs);
443- // The argument is not currently in memory, so make a temporary for the
444- // argument, and store it there, then bind that location to the argument.
445- if (loopArgs.size ()) {
446- mlir::Operation *storeOp = nullptr ;
447- for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs)) {
448- mlir::Value indexVal =
449- fir::getBase (op->getRegion (0 ).front ().getArgument (argIndex));
450- storeOp =
451- createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
452- }
453- firOpBuilder.setInsertionPointAfter (storeOp);
454- }
455- // Bind the reduction arguments to their block arguments
456- for (auto [arg, prv] : llvm::zip_equal (
457- reductionArgs,
458- llvm::drop_begin (entryBlock->getArguments (), loopArgs.size ()))) {
459- converter.bindSymbol (*arg, prv);
460- }
461-
462- return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
463- }
464-
465437static void
466438markDeclareTarget (mlir::Operation *op,
467439 Fortran::lower::AbstractConverter &converter,
@@ -1270,20 +1242,16 @@ static void genTeamsClauses(Fortran::lower::AbstractConverter &converter,
12701242static void genWsloopClauses (
12711243 Fortran::lower::AbstractConverter &converter,
12721244 Fortran::semantics::SemanticsContext &semaCtx,
1273- Fortran::lower::StatementContext &stmtCtx,
1274- Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses,
1245+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
12751246 mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps,
1276- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
12771247 llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
12781248 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
12791249 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
12801250 ClauseProcessor cp (converter, semaCtx, clauses);
1281- cp.processCollapse (loc, eval, clauseOps, iv);
12821251 cp.processNowait (clauseOps);
12831252 cp.processOrdered (clauseOps);
12841253 cp.processReduction (loc, clauseOps, &reductionTypes, &reductionSyms);
12851254 cp.processSchedule (stmtCtx, clauseOps);
1286- clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr ();
12871255 // TODO Support delayed privatization.
12881256
12891257 if (ReductionProcessor::doReductionByRef (clauseOps.reductionVars ))
@@ -1526,7 +1494,8 @@ genSimdOp(Fortran::lower::AbstractConverter &converter,
15261494 auto *nestedEval = getCollapsedLoopEval (eval, getCollapseValue (clauses));
15271495
15281496 auto ivCallback = [&](mlir::Operation *op) {
1529- return genLoopVars (op, converter, loc, iv);
1497+ genLoopVars (op, converter, loc, iv);
1498+ return iv;
15301499 };
15311500
15321501 createBodyOfOp (*loopOp,
@@ -1801,32 +1770,48 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
18011770 Fortran::semantics::SemanticsContext &semaCtx,
18021771 Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
18031772 const List<Clause> &clauses) {
1773+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
18041774 DataSharingProcessor dsp (converter, semaCtx, clauses, eval);
18051775 dsp.processStep1 ();
18061776
18071777 Fortran::lower::StatementContext stmtCtx;
1808- mlir::omp::WsloopClauseOps clauseOps;
1778+ mlir::omp::LoopNestClauseOps loopClauseOps;
1779+ mlir::omp::WsloopClauseOps wsClauseOps;
18091780 llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
18101781 llvm::SmallVector<mlir::Type> reductionTypes;
18111782 llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
1812- genWsloopClauses (converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps,
1813- iv, reductionTypes, reductionSyms);
1783+ genLoopNestClauses (converter, semaCtx, eval, clauses, loc, loopClauseOps, iv);
1784+ genWsloopClauses (converter, semaCtx, stmtCtx, clauses, loc, wsClauseOps,
1785+ reductionTypes, reductionSyms);
1786+
1787+ // Create omp.wsloop wrapper and populate entry block arguments with reduction
1788+ // variables.
1789+ auto wsloopOp = firOpBuilder.create <mlir::omp::WsloopOp>(loc, wsClauseOps);
1790+ llvm::SmallVector<mlir::Location> reductionLocs (reductionSyms.size (), loc);
1791+ mlir::Block *wsloopEntryBlock = firOpBuilder.createBlock (
1792+ &wsloopOp.getRegion (), {}, reductionTypes, reductionLocs);
1793+ firOpBuilder.setInsertionPoint (
1794+ Fortran::lower::genOpenMPTerminator (firOpBuilder, wsloopOp, loc));
1795+
1796+ // Create nested omp.loop_nest and fill body with loop contents.
1797+ auto loopOp = firOpBuilder.create <mlir::omp::LoopNestOp>(loc, loopClauseOps);
18141798
18151799 auto *nestedEval = getCollapsedLoopEval (eval, getCollapseValue (clauses));
18161800
18171801 auto ivCallback = [&](mlir::Operation *op) {
1818- return genLoopAndReductionVars (op, converter, loc, iv, reductionSyms,
1819- reductionTypes);
1802+ genLoopVars (op, converter, loc, iv, reductionSyms,
1803+ wsloopEntryBlock->getArguments ());
1804+ return iv;
18201805 };
18211806
1822- return genOpWithBody<mlir::omp::WsloopOp>(
1823- OpWithBodyGenInfo (converter, semaCtx, loc, *nestedEval,
1824- llvm::omp::Directive::OMPD_do)
1825- .setClauses (&clauses)
1826- .setDataSharingProcessor (&dsp)
1827- .setReductions (&reductionSyms, &reductionTypes)
1828- .setGenRegionEntryCb (ivCallback),
1829- clauseOps) ;
1807+ createBodyOfOp (*loopOp,
1808+ OpWithBodyGenInfo (converter, semaCtx, loc, *nestedEval,
1809+ llvm::omp::Directive::OMPD_do)
1810+ .setClauses (&clauses)
1811+ .setDataSharingProcessor (&dsp)
1812+ .setReductions (&reductionSyms, &reductionTypes)
1813+ .setGenRegionEntryCb (ivCallback));
1814+ return wsloopOp ;
18301815}
18311816
18321817// ===----------------------------------------------------------------------===//
@@ -2482,8 +2467,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
24822467mlir::Operation *Fortran::lower::genOpenMPTerminator (fir::FirOpBuilder &builder,
24832468 mlir::Operation *op,
24842469 mlir::Location loc) {
2485- if (mlir::isa<mlir::omp::WsloopOp , mlir::omp::DeclareReductionOp,
2486- mlir::omp::AtomicUpdateOp, mlir::omp:: LoopNestOp>(op))
2470+ if (mlir::isa<mlir::omp::AtomicUpdateOp , mlir::omp::DeclareReductionOp,
2471+ mlir::omp::LoopNestOp>(op))
24872472 return builder.create <mlir::omp::YieldOp>(loc);
24882473 return builder.create <mlir::omp::TerminatorOp>(loc);
24892474}
0 commit comments