Skip to content

Commit 2af3c27

Browse files
authored
Merge pull request #1063 from schweitzpgi/ch-forall
Fix evaluation semantics of FORALL constructs per 10.2.4.2.4.
2 parents bb1ab87 + 126aa69 commit 2af3c27

File tree

9 files changed

+944
-664
lines changed

9 files changed

+944
-664
lines changed

flang/include/flang/Lower/ConvertExpr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ createSomeArrayTempValue(AbstractConverter &converter,
176176
fir::ExtendedValue
177177
createLazyArrayTempValue(AbstractConverter &converter,
178178
const evaluate::Expr<evaluate::SomeType> &expr,
179-
mlir::Value var, SymMap &symMap,
180-
StatementContext &stmtCtx);
179+
mlir::Value var, mlir::Value shapeBuffer,
180+
SymMap &symMap, StatementContext &stmtCtx);
181181

182182
/// Lower an array expression to a value of type box. The expression must be a
183183
/// variable.

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,10 @@ class FirOpBuilder : public mlir::OpBuilder {
248248
}
249249

250250
/// Construct one of the two forms of shape op from an array box.
251-
mlir::Value consShape(mlir::Location loc, const fir::AbstractArrayBox &arr);
252-
mlir::Value consShape(mlir::Location loc, llvm::ArrayRef<mlir::Value> shift,
251+
mlir::Value genShape(mlir::Location loc, const fir::AbstractArrayBox &arr);
252+
mlir::Value genShape(mlir::Location loc, llvm::ArrayRef<mlir::Value> shift,
253253
llvm::ArrayRef<mlir::Value> exts);
254-
mlir::Value consShape(mlir::Location loc, llvm::ArrayRef<mlir::Value> exts);
254+
mlir::Value genShape(mlir::Location loc, llvm::ArrayRef<mlir::Value> exts);
255255

256256
/// Create one of the shape ops given an extended value. For a boxed value,
257257
/// this may create a `fir.shift` op.

flang/lib/Lower/Bridge.cpp

Lines changed: 115 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,80 +1184,123 @@ class FirConverter : public Fortran::lower::AbstractConverter {
11841184
/// Process a concurrent header for a FORALL. (Concurrent headers for DO
11851185
/// CONCURRENT loops are lowered elsewhere.)
11861186
void genFIR(const Fortran::parser::ConcurrentHeader &header) {
1187-
// Create our iteration space from the header spec.
1188-
localSymbols.pushScope();
1189-
auto idxTy = builder->getIndexType();
1190-
auto loc = toLocation();
1191-
llvm::SmallVector<fir::DoLoopOp> loops;
1192-
for (auto &ctrl :
1193-
std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t)) {
1194-
const auto *ctrlVar = std::get<Fortran::parser::Name>(ctrl.t).symbol;
1195-
const auto *lo = Fortran::semantics::GetExpr(std::get<1>(ctrl.t));
1196-
const auto *hi = Fortran::semantics::GetExpr(std::get<2>(ctrl.t));
1197-
auto &optStep =
1198-
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(ctrl.t);
1199-
auto lb = builder->createConvert(
1200-
loc, idxTy,
1201-
fir::getBase(genExprValue(*lo, explicitIterSpace.stmtContext())));
1202-
auto ub = builder->createConvert(
1203-
loc, idxTy,
1204-
fir::getBase(genExprValue(*hi, explicitIterSpace.stmtContext())));
1205-
auto by = optStep.has_value()
1206-
? builder->createConvert(
1207-
loc, idxTy,
1208-
fir::getBase(genExprValue(
1209-
*Fortran::semantics::GetExpr(*optStep),
1210-
explicitIterSpace.stmtContext())))
1211-
: builder->createIntegerConstant(loc, idxTy, 1);
1212-
auto lp = builder->create<fir::DoLoopOp>(
1213-
loc, lb, ub, by, /*unordered=*/true,
1214-
/*finalCount=*/false, explicitIterSpace.getInnerArgs());
1215-
if (!loops.empty())
1216-
builder->create<fir::ResultOp>(loc, lp.getResults());
1217-
explicitIterSpace.setInnerArgs(lp.getRegionIterArgs());
1218-
builder->setInsertionPointToStart(lp.getBody());
1219-
forceControlVariableBinding(ctrlVar, lp.getInductionVar());
1220-
loops.push_back(lp);
1221-
}
1222-
explicitIterSpace.setOuterLoop(loops[0]);
1223-
if (const auto &mask =
1224-
std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(
1225-
header.t);
1226-
mask.has_value()) {
1227-
auto i1Ty = builder->getI1Type();
1228-
auto maskExv = genExprValue(*Fortran::semantics::GetExpr(mask.value()),
1229-
explicitIterSpace.stmtContext());
1230-
auto cond = builder->createConvert(loc, i1Ty, fir::getBase(maskExv));
1231-
auto ifOp = builder->create<fir::IfOp>(
1232-
loc, explicitIterSpace.innerArgTypes(), cond,
1233-
/*withElseRegion=*/true);
1234-
builder->create<fir::ResultOp>(loc, ifOp.getResults());
1235-
builder->setInsertionPointToStart(&ifOp.elseRegion().front());
1236-
builder->create<fir::ResultOp>(loc, explicitIterSpace.getInnerArgs());
1237-
builder->setInsertionPointToStart(&ifOp.thenRegion().front());
1187+
llvm::SmallVector<mlir::Value> lows;
1188+
llvm::SmallVector<mlir::Value> highs;
1189+
llvm::SmallVector<mlir::Value> steps;
1190+
if (explicitIterSpace.isOutermostForall()) {
1191+
// For the outermost forall, we evaluate the bounds expressions once.
1192+
// Contrastingly, if this forall is nested, the bounds expressions are
1193+
// assumed to be pure, possibly dependent on outer concurrent control
1194+
// variables, possibly variant with respect to arguments, and will be
1195+
// re-evaluated.
1196+
auto loc = toLocation();
1197+
auto idxTy = builder->getIndexType();
1198+
auto &stmtCtx = explicitIterSpace.stmtContext();
1199+
auto lowerExpr = [&](auto &e) {
1200+
return fir::getBase(genExprValue(e, stmtCtx));
1201+
};
1202+
for (auto &ctrl :
1203+
std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t)) {
1204+
const auto *lo = Fortran::semantics::GetExpr(std::get<1>(ctrl.t));
1205+
const auto *hi = Fortran::semantics::GetExpr(std::get<2>(ctrl.t));
1206+
auto &optStep =
1207+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(ctrl.t);
1208+
lows.push_back(builder->createConvert(loc, idxTy, lowerExpr(*lo)));
1209+
highs.push_back(builder->createConvert(loc, idxTy, lowerExpr(*hi)));
1210+
steps.push_back(
1211+
optStep.has_value()
1212+
? builder->createConvert(
1213+
loc, idxTy,
1214+
lowerExpr(*Fortran::semantics::GetExpr(*optStep)))
1215+
: builder->createIntegerConstant(loc, idxTy, 1));
1216+
}
12381217
}
1218+
auto lambda = [&, lows, highs, steps]() {
1219+
// Create our iteration space from the header spec.
1220+
auto loc = toLocation();
1221+
auto idxTy = builder->getIndexType();
1222+
llvm::SmallVector<fir::DoLoopOp> loops;
1223+
auto &stmtCtx = explicitIterSpace.stmtContext();
1224+
auto lowerExpr = [&](auto &e) {
1225+
return fir::getBase(genExprValue(e, stmtCtx));
1226+
};
1227+
const auto outermost = !lows.empty();
1228+
std::size_t headerIndex = 0;
1229+
for (auto &ctrl :
1230+
std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t)) {
1231+
const auto *ctrlVar = std::get<Fortran::parser::Name>(ctrl.t).symbol;
1232+
mlir::Value lb;
1233+
mlir::Value ub;
1234+
mlir::Value by;
1235+
if (outermost) {
1236+
assert(headerIndex < lows.size());
1237+
lb = lows[headerIndex];
1238+
ub = highs[headerIndex];
1239+
by = steps[headerIndex++];
1240+
} else {
1241+
const auto *lo = Fortran::semantics::GetExpr(std::get<1>(ctrl.t));
1242+
const auto *hi = Fortran::semantics::GetExpr(std::get<2>(ctrl.t));
1243+
auto &optStep =
1244+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(ctrl.t);
1245+
lb = builder->createConvert(loc, idxTy, lowerExpr(*lo));
1246+
ub = builder->createConvert(loc, idxTy, lowerExpr(*hi));
1247+
by = optStep.has_value()
1248+
? builder->createConvert(
1249+
loc, idxTy,
1250+
lowerExpr(*Fortran::semantics::GetExpr(*optStep)))
1251+
: builder->createIntegerConstant(loc, idxTy, 1);
1252+
}
1253+
auto lp = builder->create<fir::DoLoopOp>(
1254+
loc, lb, ub, by, /*unordered=*/true,
1255+
/*finalCount=*/false, explicitIterSpace.getInnerArgs());
1256+
if (!loops.empty() || !outermost)
1257+
builder->create<fir::ResultOp>(loc, lp.getResults());
1258+
explicitIterSpace.setInnerArgs(lp.getRegionIterArgs());
1259+
builder->setInsertionPointToStart(lp.getBody());
1260+
forceControlVariableBinding(ctrlVar, lp.getInductionVar());
1261+
loops.push_back(lp);
1262+
}
1263+
explicitIterSpace.setOuterLoop(loops[0]);
1264+
if (const auto &mask =
1265+
std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(
1266+
header.t);
1267+
mask.has_value()) {
1268+
auto i1Ty = builder->getI1Type();
1269+
auto maskExv =
1270+
genExprValue(*Fortran::semantics::GetExpr(mask.value()), stmtCtx);
1271+
auto cond = builder->createConvert(loc, i1Ty, fir::getBase(maskExv));
1272+
auto ifOp = builder->create<fir::IfOp>(
1273+
loc, explicitIterSpace.innerArgTypes(), cond,
1274+
/*withElseRegion=*/true);
1275+
builder->create<fir::ResultOp>(loc, ifOp.getResults());
1276+
builder->setInsertionPointToStart(&ifOp.elseRegion().front());
1277+
builder->create<fir::ResultOp>(loc, explicitIterSpace.getInnerArgs());
1278+
builder->setInsertionPointToStart(&ifOp.thenRegion().front());
1279+
}
1280+
};
1281+
// Push the lambda to gen the loop nest context.
1282+
explicitIterSpace.pushLoopNest(lambda);
12391283
}
12401284

12411285
void genFIR(const Fortran::parser::ForallAssignmentStmt &stmt) {
12421286
std::visit([&](const auto &x) { genFIR(x); }, stmt.u);
12431287
}
12441288

12451289
void genFIR(const Fortran::parser::EndForallStmt &) {
1246-
explicitIterSpace.finalize();
12471290
cleanupExplicitSpace();
12481291
}
12491292

12501293
template <typename A>
12511294
void prepareExplicitSpace(const A &forall) {
1252-
analyzeExplicitSpace(forall);
1295+
if (!explicitIterSpace.isActive())
1296+
analyzeExplicitSpace(forall);
1297+
localSymbols.pushScope();
12531298
explicitIterSpace.enter();
1254-
Fortran::lower::createArrayLoads(*this, explicitIterSpace, localSymbols);
12551299
}
12561300

12571301
/// Cleanup all the FORALL context information when we exit.
12581302
void cleanupExplicitSpace() {
1259-
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
1260-
explicitIterSpace.conditionalCleanup();
1303+
explicitIterSpace.leave();
12611304
localSymbols.popScope();
12621305
}
12631306

@@ -1797,6 +1840,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
17971840
void genAssignment(const Fortran::evaluate::Assignment &assign) {
17981841
Fortran::lower::StatementContext stmtCtx;
17991842
auto loc = toLocation();
1843+
if (explicitIterationSpace()) {
1844+
Fortran::lower::createArrayLoads(*this, explicitIterSpace, localSymbols);
1845+
explicitIterSpace.genLoopNest();
1846+
}
18001847
std::visit(
18011848
Fortran::common::visitors{
18021849
// [1] Plain old assignment.
@@ -1893,7 +1940,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
18931940
if (implicitIterationSpace())
18941941
TODO(loc, "user defined assignment within WHERE");
18951942
Fortran::semantics::SomeExpr expr{procRef};
1896-
createFIRExpr(toLocation(), &expr, stmtCtx);
1943+
createFIRExpr(toLocation(), &expr,
1944+
explicitIterationSpace()
1945+
? explicitIterSpace.stmtContext()
1946+
: stmtCtx);
18971947
},
18981948

18991949
// [3] Pointer assignment with possibly empty bounds-spec. R1035: a
@@ -1954,6 +2004,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19542004
},
19552005
},
19562006
assign.u);
2007+
if (explicitIterationSpace())
2008+
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
19572009
}
19582010

19592011
void genFIR(const Fortran::parser::WhereConstruct &c) {
@@ -2536,6 +2588,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25362588
void analyzeExplicitSpace(const Fortran::evaluate::Assignment *assign) {
25372589
analyzeExplicitSpace</*LHS=*/true>(assign->lhs);
25382590
analyzeExplicitSpace(assign->rhs);
2591+
explicitIterSpace.endAssign();
25392592
}
25402593
void analyzeExplicitSpace(const Fortran::parser::ForallAssignmentStmt &stmt) {
25412594
std::visit([&](const auto &s) { analyzeExplicitSpace(s); }, stmt.u);
@@ -2666,7 +2719,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26662719
auto var = builder->createTemporary(loc, ty);
26672720
auto nil = builder->createNullConstant(loc, ty);
26682721
builder->create<fir::StoreOp>(loc, nil, var);
2669-
implicitIterSpace.addMaskVariable(exp, var);
2722+
auto shTy = fir::HeapType::get(builder->getIndexType());
2723+
auto shape = builder->createTemporary(loc, shTy);
2724+
auto nilSh = builder->createNullConstant(loc, shTy);
2725+
builder->create<fir::StoreOp>(loc, nilSh, shape);
2726+
implicitIterSpace.addMaskVariable(exp, var, shape);
26702727
explicitIterSpace.outermostContext().attachCleanup([=]() {
26712728
auto load = builder->create<fir::LoadOp>(loc, var);
26722729
auto cmp = builder->genIsNotNull(loc, load);

0 commit comments

Comments
 (0)