@@ -456,6 +456,33 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
456
456
declareTargetOp.setDeclareTarget (deviceType, captureClause);
457
457
}
458
458
459
+ // / For an operation that takes `omp.private` values as region args, this util
460
+ // / merges the private vars info into the region arguments list.
461
+ // /
462
+ // / \tparam OMPOP - the OpenMP op that takes `omp.private` inputs.
463
+ // / \tparam InfoTy - the type of private info we want to merge; e.g. mlir::Type
464
+ // / or mlir::Location fields of the private var list.
465
+ // /
466
+ // / \param [in] op - the op accepting `omp.private` inputs.
467
+ // / \param [in] currentList - the current list of region info that we
468
+ // / want to merge private info with. For example this could be the list of types
469
+ // / or locations of previous arguments to \op's region.
470
+ // / \param [in] infoAccessor - for a private variable, this returns the
471
+ // / data we want to merge: type or location.
472
+ // / \param [out] allRegionArgsInfo - the merged list of region info.
473
+ template <typename OMPOp, typename InfoTy>
474
+ static void
475
+ mergePrivateVarsInfo (OMPOp op, llvm::ArrayRef<InfoTy> currentList,
476
+ llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
477
+ llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
478
+ mlir::OperandRange privateVars = op.getPrivateVars ();
479
+
480
+ llvm::transform (currentList, std::back_inserter (allRegionArgsInfo),
481
+ [](InfoTy i) { return i; });
482
+ llvm::transform (privateVars, std::back_inserter (allRegionArgsInfo),
483
+ infoAccessor);
484
+ }
485
+
459
486
// ===----------------------------------------------------------------------===//
460
487
// Op body generation helper structures and functions
461
488
// ===----------------------------------------------------------------------===//
@@ -758,15 +785,28 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
758
785
llvm::ArrayRef<const semantics::Symbol *> mapSyms,
759
786
llvm::ArrayRef<mlir::Location> mapSymLocs,
760
787
llvm::ArrayRef<mlir::Type> mapSymTypes,
788
+ DataSharingProcessor &dsp,
761
789
const mlir::Location ¤tLocation,
762
790
const ConstructQueue &queue, ConstructQueue::iterator item) {
763
791
assert (mapSymTypes.size () == mapSymLocs.size ());
764
792
765
793
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
766
794
mlir::Region ®ion = targetOp.getRegion ();
767
795
768
- auto *regionBlock =
769
- firOpBuilder.createBlock (®ion, {}, mapSymTypes, mapSymLocs);
796
+ llvm::SmallVector<mlir::Type> allRegionArgTypes;
797
+ mergePrivateVarsInfo (targetOp, mapSymTypes,
798
+ llvm::function_ref<mlir::Type (mlir::Value)>{
799
+ [](mlir::Value v) { return v.getType (); }},
800
+ allRegionArgTypes);
801
+
802
+ llvm::SmallVector<mlir::Location> allRegionArgLocs;
803
+ mergePrivateVarsInfo (targetOp, mapSymLocs,
804
+ llvm::function_ref<mlir::Location (mlir::Value)>{
805
+ [](mlir::Value v) { return v.getLoc (); }},
806
+ allRegionArgLocs);
807
+
808
+ auto *regionBlock = firOpBuilder.createBlock (®ion, {}, allRegionArgTypes,
809
+ allRegionArgLocs);
770
810
771
811
// Clones the `bounds` placing them inside the target region and returns them.
772
812
auto cloneBound = [&](mlir::Value bound) {
@@ -830,6 +870,20 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
830
870
});
831
871
}
832
872
873
+ for (auto [argIndex, argSymbol] :
874
+ llvm::enumerate (dsp.getAllSymbolsToPrivatize ())) {
875
+ argIndex = mapSyms.size () + argIndex;
876
+
877
+ const mlir::BlockArgument &arg = region.getArgument (argIndex);
878
+ converter.bindSymbol (*argSymbol,
879
+ hlfir::translateToExtendedValue (
880
+ currentLocation, firOpBuilder, hlfir::Entity{arg},
881
+ /* contiguousHint=*/
882
+ evaluate::IsSimplyContiguous (
883
+ *argSymbol, converter.getFoldingContext ()))
884
+ .first );
885
+ }
886
+
833
887
// Check if cloning the bounds introduced any dependency on the outer region.
834
888
// If so, then either clone them as well if they are MemoryEffectFree, or else
835
889
// copy them to a new temporary and add them to the map and block_argument
@@ -907,6 +961,8 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
907
961
} else {
908
962
genNestedEvaluations (converter, eval);
909
963
}
964
+
965
+ dsp.processStep2 (targetOp, /* isLoop=*/ false );
910
966
}
911
967
912
968
template <typename OpTy, typename ... Args>
@@ -1048,15 +1104,18 @@ static void genTargetClauses(
1048
1104
devicePtrSyms);
1049
1105
cp.processMap (loc, stmtCtx, clauseOps, &mapSyms, &mapLocs, &mapTypes);
1050
1106
cp.processThreadLimit (stmtCtx, clauseOps);
1051
- // TODO Support delayed privatization.
1052
1107
1053
1108
if (processHostOnlyClauses)
1054
1109
cp.processNowait (clauseOps);
1055
1110
1056
1111
cp.processTODO <clause::Allocate, clause::Defaultmap, clause::Firstprivate,
1057
- clause::InReduction, clause::Private, clause:: Reduction,
1112
+ clause::InReduction, clause::Reduction,
1058
1113
clause::UsesAllocators>(loc,
1059
1114
llvm::omp::Directive::OMPD_target);
1115
+
1116
+ // `target private(..)` is only supported in delayed privatization mode.
1117
+ if (!enableDelayedPrivatization)
1118
+ cp.processTODO <clause::Private>(loc, llvm::omp::Directive::OMPD_target);
1060
1119
}
1061
1120
1062
1121
static void genTargetDataClauses (
@@ -1289,7 +1348,6 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1289
1348
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1290
1349
lower::StatementContext stmtCtx;
1291
1350
mlir::omp::ParallelClauseOps clauseOps;
1292
- llvm::SmallVector<const semantics::Symbol *> privateSyms;
1293
1351
llvm::SmallVector<mlir::Type> reductionTypes;
1294
1352
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
1295
1353
genParallelClauses (converter, semaCtx, stmtCtx, item->clauses , loc,
@@ -1319,34 +1377,35 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1319
1377
/* useDelayedPrivatization=*/ true , &symTable);
1320
1378
1321
1379
if (privatize)
1322
- dsp.processStep1 (&clauseOps, &privateSyms );
1380
+ dsp.processStep1 (&clauseOps);
1323
1381
1324
1382
auto genRegionEntryCB = [&](mlir::Operation *op) {
1325
1383
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
1326
1384
1327
1385
llvm::SmallVector<mlir::Location> reductionLocs (
1328
1386
clauseOps.reductionVars .size (), loc);
1329
1387
1330
- mlir::OperandRange privateVars = parallelOp.getPrivateVars ();
1331
- mlir::Region ®ion = parallelOp.getRegion ();
1388
+ llvm::SmallVector<mlir::Type> allRegionArgTypes;
1389
+ mergePrivateVarsInfo (parallelOp, llvm::ArrayRef (reductionTypes),
1390
+ llvm::function_ref<mlir::Type (mlir::Value)>{
1391
+ [](mlir::Value v) { return v.getType (); }},
1392
+ allRegionArgTypes);
1332
1393
1333
- llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes;
1334
- privateVarTypes.reserve (privateVarTypes.size () + privateVars.size ());
1335
- llvm::transform (privateVars, std::back_inserter (privateVarTypes),
1336
- [](mlir::Value v) { return v.getType (); });
1394
+ llvm::SmallVector<mlir::Location> allRegionArgLocs;
1395
+ mergePrivateVarsInfo (parallelOp, llvm::ArrayRef (reductionLocs),
1396
+ llvm::function_ref<mlir::Location (mlir::Value)>{
1397
+ [](mlir::Value v) { return v.getLoc (); }},
1398
+ allRegionArgLocs);
1337
1399
1338
- llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs;
1339
- privateVarLocs.reserve (privateVarLocs.size () + privateVars.size ());
1340
- llvm::transform (privateVars, std::back_inserter (privateVarLocs),
1341
- [](mlir::Value v) { return v.getLoc (); });
1342
-
1343
- firOpBuilder.createBlock (®ion, /* insertPt=*/ {}, privateVarTypes,
1344
- privateVarLocs);
1400
+ mlir::Region ®ion = parallelOp.getRegion ();
1401
+ firOpBuilder.createBlock (®ion, /* insertPt=*/ {}, allRegionArgTypes,
1402
+ allRegionArgLocs);
1345
1403
1346
1404
llvm::SmallVector<const semantics::Symbol *> allSymbols = reductionSyms;
1347
- allSymbols.append (privateSyms);
1405
+ allSymbols.append (dsp.getAllSymbolsToPrivatize ().begin (),
1406
+ dsp.getAllSymbolsToPrivatize ().end ());
1407
+
1348
1408
for (auto [arg, prv] : llvm::zip_equal (allSymbols, region.getArguments ())) {
1349
- fir::ExtendedValue hostExV = converter.getSymbolExtendedValue (*arg);
1350
1409
converter.bindSymbol (*arg, hlfir::translateToExtendedValue (
1351
1410
loc, firOpBuilder, hlfir::Entity{prv},
1352
1411
/* contiguousHint=*/
@@ -1541,11 +1600,22 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1541
1600
deviceAddrLocs, deviceAddrTypes, devicePtrSyms,
1542
1601
devicePtrLocs, devicePtrTypes);
1543
1602
1603
+ llvm::SmallVector<const semantics::Symbol *> privateSyms;
1604
+ DataSharingProcessor dsp (converter, semaCtx, item->clauses , eval,
1605
+ /* shouldCollectPreDeterminedSymbols=*/
1606
+ lower::omp::isLastItemInQueue (item, queue),
1607
+ /* useDelayedPrivatization=*/ true , &symTable);
1608
+ dsp.processStep1 (&clauseOps);
1609
+
1544
1610
// 5.8.1 Implicit Data-Mapping Attribute Rules
1545
1611
// The following code follows the implicit data-mapping rules to map all the
1546
- // symbols used inside the region that have not been explicitly mapped using
1547
- // the map clause.
1612
+ // symbols used inside the region that do not have explicit data-environment
1613
+ // attribute clauses (neither data-sharing; e.g. `private`, nor `map`
1614
+ // clauses).
1548
1615
auto captureImplicitMap = [&](const semantics::Symbol &sym) {
1616
+ if (dsp.getAllSymbolsToPrivatize ().contains (&sym))
1617
+ return ;
1618
+
1549
1619
if (llvm::find (mapSyms, &sym) == mapSyms.end ()) {
1550
1620
mlir::Value baseOp = converter.getSymbolAddress (sym);
1551
1621
if (!baseOp)
@@ -1632,7 +1702,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1632
1702
1633
1703
auto targetOp = firOpBuilder.create <mlir::omp::TargetOp>(loc, clauseOps);
1634
1704
genBodyOfTargetOp (converter, symTable, semaCtx, eval, targetOp, mapSyms,
1635
- mapLocs, mapTypes, loc, queue, item);
1705
+ mapLocs, mapTypes, dsp, loc, queue, item);
1636
1706
return targetOp;
1637
1707
}
1638
1708
0 commit comments