Skip to content

Commit 913a824

Browse files
authored
[flang][OpenMP] Lower target .. private(..) to omp.private ops (#94195)
Extends delayed privatization support to `taraget .. private(..)`. With this PR, `private` is support for `target` **only** is delayed privatization mode.
1 parent acc927a commit 913a824

File tree

6 files changed

+396
-47
lines changed

6 files changed

+396
-47
lines changed

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,13 @@ DataSharingProcessor::DataSharingProcessor(
4848
}
4949

5050
void DataSharingProcessor::processStep1(
51-
mlir::omp::PrivateClauseOps *clauseOps,
52-
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms) {
51+
mlir::omp::PrivateClauseOps *clauseOps) {
5352
collectSymbolsForPrivatization();
5453
collectDefaultSymbols();
5554
collectImplicitSymbols();
5655
collectPreDeterminedSymbols();
5756

58-
privatize(clauseOps, privateSyms);
57+
privatize(clauseOps);
5958

6059
insertBarrier();
6160
}
@@ -415,16 +414,14 @@ void DataSharingProcessor::collectPreDeterminedSymbols() {
415414
preDeterminedSymbols);
416415
}
417416

418-
void DataSharingProcessor::privatize(
419-
mlir::omp::PrivateClauseOps *clauseOps,
420-
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms) {
417+
void DataSharingProcessor::privatize(mlir::omp::PrivateClauseOps *clauseOps) {
421418
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
422419
if (const auto *commonDet =
423420
sym->detailsIf<semantics::CommonBlockDetails>()) {
424421
for (const auto &mem : commonDet->objects())
425-
doPrivatize(&*mem, clauseOps, privateSyms);
422+
doPrivatize(&*mem, clauseOps);
426423
} else
427-
doPrivatize(sym, clauseOps, privateSyms);
424+
doPrivatize(sym, clauseOps);
428425
}
429426
}
430427

@@ -441,9 +438,8 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
441438
}
442439
}
443440

444-
void DataSharingProcessor::doPrivatize(
445-
const semantics::Symbol *sym, mlir::omp::PrivateClauseOps *clauseOps,
446-
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms) {
441+
void DataSharingProcessor::doPrivatize(const semantics::Symbol *sym,
442+
mlir::omp::PrivateClauseOps *clauseOps) {
447443
if (!useDelayedPrivatization) {
448444
cloneSymbol(sym);
449445
copyFirstPrivateSymbol(sym);
@@ -548,9 +544,6 @@ void DataSharingProcessor::doPrivatize(
548544
clauseOps->privateVars.push_back(hsb.getAddr());
549545
}
550546

551-
if (privateSyms)
552-
privateSyms->push_back(sym);
553-
554547
symToPrivatizer[sym] = privatizerOp;
555548
}
556549

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,15 @@ class DataSharingProcessor {
105105
void collectDefaultSymbols();
106106
void collectImplicitSymbols();
107107
void collectPreDeterminedSymbols();
108-
void privatize(mlir::omp::PrivateClauseOps *clauseOps,
109-
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms);
108+
void privatize(mlir::omp::PrivateClauseOps *clauseOps);
110109
void defaultPrivatize(
111110
mlir::omp::PrivateClauseOps *clauseOps,
112111
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms);
113112
void implicitPrivatize(
114113
mlir::omp::PrivateClauseOps *clauseOps,
115114
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms);
116-
void
117-
doPrivatize(const semantics::Symbol *sym,
118-
mlir::omp::PrivateClauseOps *clauseOps,
119-
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms);
115+
void doPrivatize(const semantics::Symbol *sym,
116+
mlir::omp::PrivateClauseOps *clauseOps);
120117
void copyLastPrivatize(mlir::Operation *op);
121118
void insertLastPrivateCompare(mlir::Operation *op);
122119
void cloneSymbol(const semantics::Symbol *sym);
@@ -147,15 +144,18 @@ class DataSharingProcessor {
147144
// Step2 performs the copying for lastprivates and requires knowledge of the
148145
// MLIR operation to insert the last private update. Step2 adds
149146
// dealocation code as well.
150-
void processStep1(
151-
mlir::omp::PrivateClauseOps *clauseOps = nullptr,
152-
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms = nullptr);
147+
void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr);
153148
void processStep2(mlir::Operation *op, bool isLoop);
154149

155150
void setLoopIV(mlir::Value iv) {
156151
assert(!loopIV && "Loop iteration variable already set");
157152
loopIV = iv;
158153
}
154+
155+
const llvm::SetVector<const semantics::Symbol *> &
156+
getAllSymbolsToPrivatize() const {
157+
return allPrivatizedSymbols;
158+
}
159159
};
160160

161161
} // namespace omp

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,33 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
456456
declareTargetOp.setDeclareTarget(deviceType, captureClause);
457457
}
458458

459+
/// For an operation that takes `omp.private` values as region args, this util
460+
/// merges the private vars info into the region arguments list.
461+
///
462+
/// \tparam OMPOP - the OpenMP op that takes `omp.private` inputs.
463+
/// \tparam InfoTy - the type of private info we want to merge; e.g. mlir::Type
464+
/// or mlir::Location fields of the private var list.
465+
///
466+
/// \param [in] op - the op accepting `omp.private` inputs.
467+
/// \param [in] currentList - the current list of region info that we
468+
/// want to merge private info with. For example this could be the list of types
469+
/// or locations of previous arguments to \op's region.
470+
/// \param [in] infoAccessor - for a private variable, this returns the
471+
/// data we want to merge: type or location.
472+
/// \param [out] allRegionArgsInfo - the merged list of region info.
473+
template <typename OMPOp, typename InfoTy>
474+
static void
475+
mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
476+
llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
477+
llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
478+
mlir::OperandRange privateVars = op.getPrivateVars();
479+
480+
llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
481+
[](InfoTy i) { return i; });
482+
llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo),
483+
infoAccessor);
484+
}
485+
459486
//===----------------------------------------------------------------------===//
460487
// Op body generation helper structures and functions
461488
//===----------------------------------------------------------------------===//
@@ -758,15 +785,28 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
758785
llvm::ArrayRef<const semantics::Symbol *> mapSyms,
759786
llvm::ArrayRef<mlir::Location> mapSymLocs,
760787
llvm::ArrayRef<mlir::Type> mapSymTypes,
788+
DataSharingProcessor &dsp,
761789
const mlir::Location &currentLocation,
762790
const ConstructQueue &queue, ConstructQueue::iterator item) {
763791
assert(mapSymTypes.size() == mapSymLocs.size());
764792

765793
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
766794
mlir::Region &region = targetOp.getRegion();
767795

768-
auto *regionBlock =
769-
firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
796+
llvm::SmallVector<mlir::Type> allRegionArgTypes;
797+
mergePrivateVarsInfo(targetOp, mapSymTypes,
798+
llvm::function_ref<mlir::Type(mlir::Value)>{
799+
[](mlir::Value v) { return v.getType(); }},
800+
allRegionArgTypes);
801+
802+
llvm::SmallVector<mlir::Location> allRegionArgLocs;
803+
mergePrivateVarsInfo(targetOp, mapSymLocs,
804+
llvm::function_ref<mlir::Location(mlir::Value)>{
805+
[](mlir::Value v) { return v.getLoc(); }},
806+
allRegionArgLocs);
807+
808+
auto *regionBlock = firOpBuilder.createBlock(&region, {}, allRegionArgTypes,
809+
allRegionArgLocs);
770810

771811
// Clones the `bounds` placing them inside the target region and returns them.
772812
auto cloneBound = [&](mlir::Value bound) {
@@ -830,6 +870,20 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
830870
});
831871
}
832872

873+
for (auto [argIndex, argSymbol] :
874+
llvm::enumerate(dsp.getAllSymbolsToPrivatize())) {
875+
argIndex = mapSyms.size() + argIndex;
876+
877+
const mlir::BlockArgument &arg = region.getArgument(argIndex);
878+
converter.bindSymbol(*argSymbol,
879+
hlfir::translateToExtendedValue(
880+
currentLocation, firOpBuilder, hlfir::Entity{arg},
881+
/*contiguousHint=*/
882+
evaluate::IsSimplyContiguous(
883+
*argSymbol, converter.getFoldingContext()))
884+
.first);
885+
}
886+
833887
// Check if cloning the bounds introduced any dependency on the outer region.
834888
// If so, then either clone them as well if they are MemoryEffectFree, or else
835889
// copy them to a new temporary and add them to the map and block_argument
@@ -907,6 +961,8 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
907961
} else {
908962
genNestedEvaluations(converter, eval);
909963
}
964+
965+
dsp.processStep2(targetOp, /*isLoop=*/false);
910966
}
911967

912968
template <typename OpTy, typename... Args>
@@ -1048,15 +1104,18 @@ static void genTargetClauses(
10481104
devicePtrSyms);
10491105
cp.processMap(loc, stmtCtx, clauseOps, &mapSyms, &mapLocs, &mapTypes);
10501106
cp.processThreadLimit(stmtCtx, clauseOps);
1051-
// TODO Support delayed privatization.
10521107

10531108
if (processHostOnlyClauses)
10541109
cp.processNowait(clauseOps);
10551110

10561111
cp.processTODO<clause::Allocate, clause::Defaultmap, clause::Firstprivate,
1057-
clause::InReduction, clause::Private, clause::Reduction,
1112+
clause::InReduction, clause::Reduction,
10581113
clause::UsesAllocators>(loc,
10591114
llvm::omp::Directive::OMPD_target);
1115+
1116+
// `target private(..)` is only supported in delayed privatization mode.
1117+
if (!enableDelayedPrivatization)
1118+
cp.processTODO<clause::Private>(loc, llvm::omp::Directive::OMPD_target);
10601119
}
10611120

10621121
static void genTargetDataClauses(
@@ -1289,7 +1348,6 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
12891348
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
12901349
lower::StatementContext stmtCtx;
12911350
mlir::omp::ParallelClauseOps clauseOps;
1292-
llvm::SmallVector<const semantics::Symbol *> privateSyms;
12931351
llvm::SmallVector<mlir::Type> reductionTypes;
12941352
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
12951353
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
@@ -1319,34 +1377,35 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
13191377
/*useDelayedPrivatization=*/true, &symTable);
13201378

13211379
if (privatize)
1322-
dsp.processStep1(&clauseOps, &privateSyms);
1380+
dsp.processStep1(&clauseOps);
13231381

13241382
auto genRegionEntryCB = [&](mlir::Operation *op) {
13251383
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
13261384

13271385
llvm::SmallVector<mlir::Location> reductionLocs(
13281386
clauseOps.reductionVars.size(), loc);
13291387

1330-
mlir::OperandRange privateVars = parallelOp.getPrivateVars();
1331-
mlir::Region &region = parallelOp.getRegion();
1388+
llvm::SmallVector<mlir::Type> allRegionArgTypes;
1389+
mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionTypes),
1390+
llvm::function_ref<mlir::Type(mlir::Value)>{
1391+
[](mlir::Value v) { return v.getType(); }},
1392+
allRegionArgTypes);
13321393

1333-
llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes;
1334-
privateVarTypes.reserve(privateVarTypes.size() + privateVars.size());
1335-
llvm::transform(privateVars, std::back_inserter(privateVarTypes),
1336-
[](mlir::Value v) { return v.getType(); });
1394+
llvm::SmallVector<mlir::Location> allRegionArgLocs;
1395+
mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs),
1396+
llvm::function_ref<mlir::Location(mlir::Value)>{
1397+
[](mlir::Value v) { return v.getLoc(); }},
1398+
allRegionArgLocs);
13371399

1338-
llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs;
1339-
privateVarLocs.reserve(privateVarLocs.size() + privateVars.size());
1340-
llvm::transform(privateVars, std::back_inserter(privateVarLocs),
1341-
[](mlir::Value v) { return v.getLoc(); });
1342-
1343-
firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes,
1344-
privateVarLocs);
1400+
mlir::Region &region = parallelOp.getRegion();
1401+
firOpBuilder.createBlock(&region, /*insertPt=*/{}, allRegionArgTypes,
1402+
allRegionArgLocs);
13451403

13461404
llvm::SmallVector<const semantics::Symbol *> allSymbols = reductionSyms;
1347-
allSymbols.append(privateSyms);
1405+
allSymbols.append(dsp.getAllSymbolsToPrivatize().begin(),
1406+
dsp.getAllSymbolsToPrivatize().end());
1407+
13481408
for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
1349-
fir::ExtendedValue hostExV = converter.getSymbolExtendedValue(*arg);
13501409
converter.bindSymbol(*arg, hlfir::translateToExtendedValue(
13511410
loc, firOpBuilder, hlfir::Entity{prv},
13521411
/*contiguousHint=*/
@@ -1541,11 +1600,22 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
15411600
deviceAddrLocs, deviceAddrTypes, devicePtrSyms,
15421601
devicePtrLocs, devicePtrTypes);
15431602

1603+
llvm::SmallVector<const semantics::Symbol *> privateSyms;
1604+
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
1605+
/*shouldCollectPreDeterminedSymbols=*/
1606+
lower::omp::isLastItemInQueue(item, queue),
1607+
/*useDelayedPrivatization=*/true, &symTable);
1608+
dsp.processStep1(&clauseOps);
1609+
15441610
// 5.8.1 Implicit Data-Mapping Attribute Rules
15451611
// The following code follows the implicit data-mapping rules to map all the
1546-
// symbols used inside the region that have not been explicitly mapped using
1547-
// the map clause.
1612+
// symbols used inside the region that do not have explicit data-environment
1613+
// attribute clauses (neither data-sharing; e.g. `private`, nor `map`
1614+
// clauses).
15481615
auto captureImplicitMap = [&](const semantics::Symbol &sym) {
1616+
if (dsp.getAllSymbolsToPrivatize().contains(&sym))
1617+
return;
1618+
15491619
if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
15501620
mlir::Value baseOp = converter.getSymbolAddress(sym);
15511621
if (!baseOp)
@@ -1632,7 +1702,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
16321702

16331703
auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(loc, clauseOps);
16341704
genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, mapSyms,
1635-
mapLocs, mapTypes, loc, queue, item);
1705+
mapLocs, mapTypes, dsp, loc, queue, item);
16361706
return targetOp;
16371707
}
16381708

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
! Tests delayed privatization for `targets ... private(..)` for allocatables.
2+
3+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -mmlir --openmp-enable-delayed-privatization \
4+
! RUN: -o - %s 2>&1 | FileCheck %s
5+
! RUN: bbc -emit-hlfir -fopenmp --openmp-enable-delayed-privatization -o - %s 2>&1 \
6+
! RUN: | FileCheck %s
7+
8+
subroutine target_allocatable
9+
implicit none
10+
integer, allocatable :: alloc_var
11+
12+
!$omp target private(alloc_var)
13+
alloc_var = 10
14+
!$omp end target
15+
end subroutine target_allocatable
16+
17+
! CHECK-LABEL: omp.private {type = private}
18+
! CHECK-SAME: @[[VAR_PRIVATIZER_SYM:.*]] :
19+
! CHECK-SAME: [[TYPE:!fir.ref<!fir.box<!fir.heap<i32>>>]] alloc {
20+
! CHECK: ^bb0(%[[PRIV_ARG:.*]]: [[TYPE]]):
21+
! CHECK: %[[PRIV_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "alloc_var", {{.*}}}
22+
23+
! CHECK-NEXT: %[[PRIV_ARG_VAL:.*]] = fir.load %[[PRIV_ARG]] : !fir.ref<!fir.box<!fir.heap<i32>>>
24+
! CHECK-NEXT: %[[PRIV_ARG_BOX:.*]] = fir.box_addr %[[PRIV_ARG_VAL]] : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
25+
! CHECK-NEXT: %[[PRIV_ARG_ADDR:.*]] = fir.convert %[[PRIV_ARG_BOX]] : (!fir.heap<i32>) -> i64
26+
! CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : i64
27+
! CHECK-NEXT: %[[ALLOC_COND:.*]] = arith.cmpi ne, %[[PRIV_ARG_ADDR]], %[[C0]] : i64
28+
29+
! CHECK-NEXT: fir.if %[[ALLOC_COND]] {
30+
! CHECK: %[[PRIV_ALLOCMEM:.*]] = fir.allocmem i32 {fir.must_be_heap = true, {{.*}}}
31+
! CHECK-NEXT: %[[PRIV_ALLOCMEM_BOX:.*]] = fir.embox %[[PRIV_ALLOCMEM]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
32+
! CHECK-NEXT: fir.store %[[PRIV_ALLOCMEM_BOX]] to %[[PRIV_ALLOC]] : !fir.ref<!fir.box<!fir.heap<i32>>>
33+
! CHECK-NEXT: } else {
34+
! CHECK-NEXT: %[[ZERO_BITS:.*]] = fir.zero_bits !fir.heap<i32>
35+
! CHECK-NEXT: %[[ZERO_BOX:.*]] = fir.embox %[[ZERO_BITS]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
36+
! CHECK-NEXT: fir.store %[[ZERO_BOX]] to %[[PRIV_ALLOC]] : !fir.ref<!fir.box<!fir.heap<i32>>>
37+
! CHECK-NEXT: }
38+
39+
! CHECK-NEXT: %[[PRIV_DECL:.*]]:2 = hlfir.declare %[[PRIV_ALLOC]]
40+
! CHECK-NEXT: omp.yield(%[[PRIV_DECL]]#0 : [[TYPE]])
41+
42+
! CHECK-NEXT: } dealloc {
43+
! CHECK-NEXT: ^bb0(%[[PRIV_ARG:.*]]: [[TYPE]]):
44+
45+
! CHECK-NEXT: %[[PRIV_VAL:.*]] = fir.load %[[PRIV_ARG]]
46+
! CHECK-NEXT: %[[PRIV_ADDR:.*]] = fir.box_addr %[[PRIV_VAL]]
47+
! CHECK-NEXT: %[[PRIV_ADDR_I64:.*]] = fir.convert %[[PRIV_ADDR]]
48+
! CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : i64
49+
! CHECK-NEXT: %[[PRIV_NULL_COND:.*]] = arith.cmpi ne, %[[PRIV_ADDR_I64]], %[[C0]] : i64
50+
51+
! CHECK-NEXT: fir.if %[[PRIV_NULL_COND]] {
52+
! CHECK: %[[PRIV_VAL_2:.*]] = fir.load %[[PRIV_ARG]]
53+
! CHECK-NEXT: %[[PRIV_ADDR_2:.*]] = fir.box_addr %[[PRIV_VAL_2]]
54+
! CHECK-NEXT: fir.freemem %[[PRIV_ADDR_2]]
55+
! CHECK-NEXT: %[[ZEROS:.*]] = fir.zero_bits
56+
! CHECK-NEXT: %[[ZEROS_BOX:.*]] = fir.embox %[[ZEROS]]
57+
! CHECK-NEXT: fir.store %[[ZEROS_BOX]] to %[[PRIV_ARG]]
58+
! CHECK-NEXT: }
59+
60+
! CHECK-NEXT: omp.yield
61+
! CHECK-NEXT: }
62+
63+
64+
! CHECK-LABEL: func.func @_QPtarget_allocatable() {
65+
66+
! CHECK: %[[VAR_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<i32>>
67+
! CHECK-SAME: {bindc_name = "alloc_var", {{.*}}}
68+
! CHECK: %[[VAR_DECL:.*]]:2 = hlfir.declare %[[VAR_ALLOC]]
69+
70+
! CHECK: omp.target private(
71+
! CHECK-SAME: @[[VAR_PRIVATIZER_SYM]] %[[VAR_DECL]]#0 -> %{{.*}} : [[TYPE]]) {

0 commit comments

Comments
 (0)