Skip to content

Commit a0821f0

Browse files
committed
[Flang][OpenMP] Improve entry block argument creation and binding
The main purpose of this patch is to centralize the logic for creating MLIR operation entry blocks and for binding their arguments to the corresponding symbols. This minimizes the chances of mixing up arguments for operations that have multiple entry block argument-generating clauses, and prevents divergence while binding arguments.

Some changes implemented to this end are:

- Split the creation of the entry block and the binding of its arguments to the corresponding Fortran symbols into two functions. This enabled a significant simplification of the lowering of composite constructs, where it's no longer necessary to manually ensure that the lists of arguments and symbols refer to the same variables in the same order and also match the order expected by the `BlockArgOpenMPOpInterface`.
- Removed the redundant and error-prone passing of types and locations from `ClauseProcessor` methods. Instead, these are obtained from the values in the appropriate clause operands structure. This also simplifies the argument lists of several lowering functions.
- Access block arguments of already created MLIR operations through the `BlockArgOpenMPOpInterface` instead of directly indexing the argument list of the operation, which is not scalable as more entry block argument-generating clauses are added to an operation.
- Simplified the implementation of `genParallelOp` to no longer need to define different callbacks depending on whether delayed privatization is enabled.
1 parent f58e85a commit a0821f0

File tree

7 files changed

+560
-600
lines changed

7 files changed

+560
-600
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

+22-57
Original file line number | Diff line number | Diff line change
@@ -166,15 +166,11 @@ getIfClauseOperand(lower::AbstractConverter &converter,
166166
static void addUseDeviceClause(
167167
lower::AbstractConverter &converter, const omp::ObjectList &objects,
168168
llvm::SmallVectorImpl<mlir::Value> &operands,
169-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
170-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
171169
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
172170
genObjectList(objects, converter, operands);
173-
for (mlir::Value &operand : operands) {
171+
for (mlir::Value &operand : operands)
174172
checkMapType(operand.getLoc(), operand.getType());
175-
useDeviceTypes.push_back(operand.getType());
176-
useDeviceLocs.push_back(operand.getLoc());
177-
}
173+
178174
for (const omp::Object &object : objects)
179175
useDeviceSyms.push_back(object.sym());
180176
}
@@ -832,14 +828,12 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
832828

833829
bool ClauseProcessor::processHasDeviceAddr(
834830
mlir::omp::HasDeviceAddrClauseOps &result,
835-
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
836-
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
837-
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const {
831+
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
838832
return findRepeatableClause<omp::clause::HasDeviceAddr>(
839833
[&](const omp::clause::HasDeviceAddr &devAddrClause,
840834
const parser::CharBlock &) {
841835
addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
842-
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
836+
isDeviceSyms);
843837
});
844838
}
845839

@@ -864,14 +858,12 @@ bool ClauseProcessor::processIf(
864858

865859
bool ClauseProcessor::processIsDevicePtr(
866860
mlir::omp::IsDevicePtrClauseOps &result,
867-
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
868-
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
869-
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const {
861+
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
870862
return findRepeatableClause<omp::clause::IsDevicePtr>(
871863
[&](const omp::clause::IsDevicePtr &devPtrClause,
872864
const parser::CharBlock &) {
873865
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
874-
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
866+
isDeviceSyms);
875867
});
876868
}
877869

@@ -892,9 +884,7 @@ void ClauseProcessor::processMapObjects(
892884
std::map<const semantics::Symbol *,
893885
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
894886
llvm::SmallVectorImpl<mlir::Value> &mapVars,
895-
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
896-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
897-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
887+
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const {
898888
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
899889
for (const omp::Object &object : objects) {
900890
llvm::SmallVector<mlir::Value> bounds;
@@ -927,22 +917,15 @@ void ClauseProcessor::processMapObjects(
927917
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
928918
} else {
929919
mapVars.push_back(mapOp);
930-
if (mapSyms)
931-
mapSyms->push_back(object.sym());
932-
if (mapSymTypes)
933-
mapSymTypes->push_back(baseOp.getType());
934-
if (mapSymLocs)
935-
mapSymLocs->push_back(baseOp.getLoc());
920+
mapSyms.push_back(object.sym());
936921
}
937922
}
938923
}
939924

940925
bool ClauseProcessor::processMap(
941926
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
942927
mlir::omp::MapClauseOps &result,
943-
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
944-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
945-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
928+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms) const {
946929
// We always require tracking of symbols, even if the caller does not,
947930
// so we create an optionally used local set of symbols when the mapSyms
948931
// argument is not present.
@@ -999,12 +982,11 @@ bool ClauseProcessor::processMap(
999982
}
1000983
processMapObjects(stmtCtx, clauseLocation,
1001984
std::get<omp::ObjectList>(clause.t), mapTypeBits,
1002-
parentMemberIndices, result.mapVars, ptrMapSyms,
1003-
mapSymLocs, mapSymTypes);
985+
parentMemberIndices, result.mapVars, *ptrMapSyms);
1004986
});
1005987

1006988
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
1007-
*ptrMapSyms, mapSymTypes, mapSymLocs);
989+
*ptrMapSyms);
1008990

1009991
return clauseFound;
1010992
}
@@ -1027,16 +1009,15 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
10271009

10281010
processMapObjects(stmtCtx, clauseLocation, std::get<ObjectList>(clause.t),
10291011
mapTypeBits, parentMemberIndices, result.mapVars,
1030-
&mapSymbols);
1012+
mapSymbols);
10311013
};
10321014

10331015
bool clauseFound = findRepeatableClause<omp::clause::To>(callbackFn);
10341016
clauseFound =
10351017
findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;
10361018

10371019
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
1038-
mapSymbols,
1039-
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
1020+
mapSymbols);
10401021
return clauseFound;
10411022
}
10421023

@@ -1054,34 +1035,24 @@ bool ClauseProcessor::processNontemporal(
10541035

10551036
bool ClauseProcessor::processReduction(
10561037
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
1057-
llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
1058-
llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
1038+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
10591039
return findRepeatableClause<omp::clause::Reduction>(
10601040
[&](const omp::clause::Reduction &clause, const parser::CharBlock &) {
10611041
llvm::SmallVector<mlir::Value> reductionVars;
10621042
llvm::SmallVector<bool> reduceVarByRef;
10631043
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
10641044
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
10651045
ReductionProcessor rp;
1066-
rp.addDeclareReduction(
1067-
currentLocation, converter, clause, reductionVars, reduceVarByRef,
1068-
reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
1046+
rp.addDeclareReduction(currentLocation, converter, clause,
1047+
reductionVars, reduceVarByRef,
1048+
reductionDeclSymbols, reductionSyms);
10691049

10701050
// Copy local lists into the output.
10711051
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
10721052
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
10731053
llvm::copy(reductionDeclSymbols,
10741054
std::back_inserter(result.reductionSyms));
1075-
1076-
if (outReductionTypes) {
1077-
outReductionTypes->reserve(outReductionTypes->size() +
1078-
reductionVars.size());
1079-
llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
1080-
[](mlir::Value v) { return v.getType(); });
1081-
}
1082-
1083-
if (outReductionSyms)
1084-
llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
1055+
llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
10851056
});
10861057
}
10871058

@@ -1107,8 +1078,6 @@ bool ClauseProcessor::processEnter(
11071078

11081079
bool ClauseProcessor::processUseDeviceAddr(
11091080
lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
1110-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
1111-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
11121081
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
11131082
std::map<const semantics::Symbol *,
11141083
llvm::SmallVector<OmpMapMemberIndicesData>>
@@ -1122,19 +1091,16 @@ bool ClauseProcessor::processUseDeviceAddr(
11221091
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
11231092
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
11241093
parentMemberIndices, result.useDeviceAddrVars,
1125-
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
1094+
useDeviceSyms);
11261095
});
11271096

11281097
insertChildMapInfoIntoParent(converter, parentMemberIndices,
1129-
result.useDeviceAddrVars, useDeviceSyms,
1130-
&useDeviceTypes, &useDeviceLocs);
1098+
result.useDeviceAddrVars, useDeviceSyms);
11311099
return clauseFound;
11321100
}
11331101

11341102
bool ClauseProcessor::processUseDevicePtr(
11351103
lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
1136-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
1137-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
11381104
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
11391105
std::map<const semantics::Symbol *,
11401106
llvm::SmallVector<OmpMapMemberIndicesData>>
@@ -1148,12 +1114,11 @@ bool ClauseProcessor::processUseDevicePtr(
11481114
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
11491115
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
11501116
parentMemberIndices, result.useDevicePtrVars,
1151-
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
1117+
useDeviceSyms);
11521118
});
11531119

11541120
insertChildMapInfoIntoParent(converter, parentMemberIndices,
1155-
result.useDevicePtrVars, useDeviceSyms,
1156-
&useDeviceTypes, &useDeviceLocs);
1121+
result.useDevicePtrVars, useDeviceSyms);
11571122
return clauseFound;
11581123
}
11591124

flang/lib/Lower/OpenMP/ClauseProcessor.h

+12-26
Original file line number | Diff line number | Diff line change
@@ -68,9 +68,7 @@ class ClauseProcessor {
6868
mlir::omp::FinalClauseOps &result) const;
6969
bool processHasDeviceAddr(
7070
mlir::omp::HasDeviceAddrClauseOps &result,
71-
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
72-
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
73-
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
71+
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
7472
bool processHint(mlir::omp::HintClauseOps &result) const;
7573
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
7674
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
@@ -104,43 +102,33 @@ class ClauseProcessor {
104102
mlir::omp::IfClauseOps &result) const;
105103
bool processIsDevicePtr(
106104
mlir::omp::IsDevicePtrClauseOps &result,
107-
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
108-
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
109-
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
105+
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
110106
bool
111107
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
112108

113109
// This method is used to process a map clause.
114-
// The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
115-
// store the original type, location and Fortran symbol for the map operands.
116-
// They may be used later on to create the block_arguments for some of the
117-
// target directives that require it.
118-
bool processMap(
119-
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
120-
mlir::omp::MapClauseOps &result,
121-
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = nullptr,
122-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
123-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
110+
// The optional parameter mapSyms is used to store the original Fortran symbol
111+
// for the map operands. It may be used later on to create the block_arguments
112+
// for some of the directives that require it.
113+
bool processMap(mlir::Location currentLocation,
114+
lower::StatementContext &stmtCtx,
115+
mlir::omp::MapClauseOps &result,
116+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms =
117+
nullptr) const;
124118
bool processMotionClauses(lower::StatementContext &stmtCtx,
125119
mlir::omp::MapClauseOps &result);
126120
bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const;
127121
bool processReduction(
128122
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
129-
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
130-
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
131-
nullptr) const;
123+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
132124
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
133125
bool processUseDeviceAddr(
134126
lower::StatementContext &stmtCtx,
135127
mlir::omp::UseDeviceAddrClauseOps &result,
136-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
137-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
138128
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
139129
bool processUseDevicePtr(
140130
lower::StatementContext &stmtCtx,
141131
mlir::omp::UseDevicePtrClauseOps &result,
142-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
143-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
144132
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
145133

146134
// Call this method for these clauses that should be supported but are not
@@ -181,9 +169,7 @@ class ClauseProcessor {
181169
std::map<const semantics::Symbol *,
182170
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
183171
llvm::SmallVectorImpl<mlir::Value> &mapVars,
184-
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
185-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
186-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
172+
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;
187173

188174
lower::AbstractConverter &converter;
189175
semantics::SemanticsContext &semaCtx;

0 commit comments

Comments
 (0)