Skip to content

Commit c7c5666

Browse files
authored
[flang] Do not hoist all scalar sub-expressions from WHERE constructs (#91395)
The HLFIR pass lowering WHERE (hlfir.where op) was too aggressive in its hoisting of scalar sub-expressions from LHS/RHS/MASKS outside of the loops generated for the WHERE construct. This violated F'2023 10.2.3.2 point 10 that stipulated that elemental operations must be evaluated only for elements corresponding to true values, because scalar operations are still elemental, and hoisting them is invalid if they could have side effects (e.g, division by zero) and if the MASK is always false (i.e., the loop body is never evaluated). The difficulty is that 10.2.3.2 point 9 mandates that nonelemental function must be evaluated before the loops. So it is not possible to simply stop hoisting non hlfir.elemental operations. Marking calls with an elemental/nonelemental attribute would not allow the pass to be correct if inlining is run before and drops this information, beside, extracting the argument tree that may have been CSE-ed with the rest of the expression evaluation would be a bit combursome. Instead, lower nonelemental calls into a new hlfir.exactly_once operation that will allow retaining the information that the operations contained inside its region must be hoisted. This allows inlining to operate before if desired in order to improve alias analysis. The LowerHLFIROrderedAssignments pass is updated to only hoist the operations contained inside hlfir.exactly_once bodies.
1 parent e6d3a42 commit c7c5666

File tree

12 files changed

+525
-64
lines changed

12 files changed

+525
-64
lines changed

flang/include/flang/Lower/StatementContext.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818
#include <functional>
1919
#include <optional>
2020

21+
namespace mlir {
22+
class Location;
23+
class Region;
24+
} // namespace mlir
25+
26+
namespace fir {
27+
class FirOpBuilder;
28+
}
29+
2130
namespace Fortran::lower {
2231

2332
/// When lowering a statement, temporaries for intermediate results may be
@@ -105,6 +114,11 @@ class StatementContext {
105114
llvm::SmallVector<std::optional<CleanupFunction>> cufs;
106115
};
107116

117+
/// If \p context contains any cleanups, ensure \p region has a block, and
118+
/// generate the cleanup inside that block.
119+
void genCleanUpInRegionIfAny(mlir::Location loc, fir::FirOpBuilder &builder,
120+
mlir::Region &region, StatementContext &context);
121+
108122
} // namespace Fortran::lower
109123

110124
#endif // FORTRAN_LOWER_STATEMENTCONTEXT_H

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1330,7 +1330,8 @@ def hlfir_RegionAssignOp : hlfir_Op<"region_assign", [hlfir_OrderedAssignmentTre
13301330
}
13311331

13321332
def hlfir_YieldOp : hlfir_Op<"yield", [Terminator, ParentOneOf<["RegionAssignOp",
1333-
"ElementalAddrOp", "ForallOp", "ForallMaskOp", "WhereOp", "ElseWhereOp"]>,
1333+
"ElementalAddrOp", "ForallOp", "ForallMaskOp", "WhereOp", "ElseWhereOp",
1334+
"ExactlyOnceOp"]>,
13341335
SingleBlockImplicitTerminator<"fir::FirEndOp">, RecursivelySpeculatable,
13351336
RecursiveMemoryEffects]> {
13361337

@@ -1595,6 +1596,27 @@ def hlfir_ForallMaskOp : hlfir_AssignmentMaskOp<"forall_mask"> {
15951596
let hasVerifier = 1;
15961597
}
15971598

1599+
def hlfir_ExactlyOnceOp : hlfir_Op<"exactly_once", [RecursiveMemoryEffects]> {
1600+
let summary = "Execute exactly once its region in a WhereOp";
1601+
let description = [{
1602+
Inside a Where assignment, Fortran requires a non elemental call and its
1603+
arguments to be executed exactly once, regardless of the mask values.
1604+
This operation allows holding these evaluations that cannot be hoisted
1605+
until potential parent Forall loops have been created.
1606+
It also allows inlining the calls without losing the information that
1607+
these calls must be hoisted.
1608+
}];
1609+
1610+
let regions = (region SizedRegion<1>:$body);
1611+
1612+
let results = (outs AnyFortranEntity:$result);
1613+
1614+
let assemblyFormat = [{
1615+
attr-dict `:` type($result)
1616+
$body
1617+
}];
1618+
}
1619+
15981620
def hlfir_WhereOp : hlfir_AssignmentMaskOp<"where"> {
15991621
let summary = "Represent a Fortran where construct or statement";
16001622
let description = [{

flang/lib/Lower/Bridge.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3687,22 +3687,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
36873687
return hlfir::Entity{valueAndPair.first};
36883688
}
36893689

3690-
static void
3691-
genCleanUpInRegionIfAny(mlir::Location loc, fir::FirOpBuilder &builder,
3692-
mlir::Region &region,
3693-
Fortran::lower::StatementContext &context) {
3694-
if (!context.hasCode())
3695-
return;
3696-
mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
3697-
if (region.empty())
3698-
builder.createBlock(&region);
3699-
else
3700-
builder.setInsertionPointToEnd(&region.front());
3701-
context.finalizeAndPop();
3702-
hlfir::YieldOp::ensureTerminator(region, builder, loc);
3703-
builder.restoreInsertionPoint(insertPt);
3704-
}
3705-
37063690
bool firstDummyIsPointerOrAllocatable(
37073691
const Fortran::evaluate::ProcedureRef &userDefinedAssignment) {
37083692
using DummyAttr = Fortran::evaluate::characteristics::DummyDataObject::Attr;
@@ -3928,23 +3912,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
39283912
Fortran::lower::StatementContext rhsContext;
39293913
hlfir::Entity rhs = evaluateRhs(rhsContext);
39303914
auto rhsYieldOp = builder.create<hlfir::YieldOp>(loc, rhs);
3931-
genCleanUpInRegionIfAny(loc, builder, rhsYieldOp.getCleanup(), rhsContext);
3915+
Fortran::lower::genCleanUpInRegionIfAny(
3916+
loc, builder, rhsYieldOp.getCleanup(), rhsContext);
39323917
// Lower LHS in its own region.
39333918
builder.createBlock(&regionAssignOp.getLhsRegion());
39343919
Fortran::lower::StatementContext lhsContext;
39353920
mlir::Value lhsYield = nullptr;
39363921
if (!lhsHasVectorSubscripts) {
39373922
hlfir::Entity lhs = evaluateLhs(lhsContext);
39383923
auto lhsYieldOp = builder.create<hlfir::YieldOp>(loc, lhs);
3939-
genCleanUpInRegionIfAny(loc, builder, lhsYieldOp.getCleanup(),
3940-
lhsContext);
3924+
Fortran::lower::genCleanUpInRegionIfAny(
3925+
loc, builder, lhsYieldOp.getCleanup(), lhsContext);
39413926
lhsYield = lhs;
39423927
} else {
39433928
hlfir::ElementalAddrOp elementalAddr =
39443929
Fortran::lower::convertVectorSubscriptedExprToElementalAddr(
39453930
loc, *this, assign.lhs, localSymbols, lhsContext);
3946-
genCleanUpInRegionIfAny(loc, builder, elementalAddr.getCleanup(),
3947-
lhsContext);
3931+
Fortran::lower::genCleanUpInRegionIfAny(
3932+
loc, builder, elementalAddr.getCleanup(), lhsContext);
39483933
lhsYield = elementalAddr.getYieldOp().getEntity();
39493934
}
39503935
assert(lhsYield && "must have been set");
@@ -4299,7 +4284,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
42994284
loc, *this, *maskExpr, localSymbols, maskContext);
43004285
mask = hlfir::loadTrivialScalar(loc, *builder, mask);
43014286
auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask);
4302-
genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext);
4287+
Fortran::lower::genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(),
4288+
maskContext);
43034289
}
43044290
void genFIR(const Fortran::parser::WhereConstructStmt &stmt) {
43054291
const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr(
@@ -5599,3 +5585,18 @@ Fortran::lower::LoweringBridge::LoweringBridge(
55995585
fir::support::setMLIRDataLayout(*module.get(),
56005586
targetMachine.createDataLayout());
56015587
}
5588+
5589+
void Fortran::lower::genCleanUpInRegionIfAny(
5590+
mlir::Location loc, fir::FirOpBuilder &builder, mlir::Region &region,
5591+
Fortran::lower::StatementContext &context) {
5592+
if (!context.hasCode())
5593+
return;
5594+
mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
5595+
if (region.empty())
5596+
builder.createBlock(&region);
5597+
else
5598+
builder.setInsertionPointToEnd(&region.front());
5599+
context.finalizeAndPop();
5600+
hlfir::YieldOp::ensureTerminator(region, builder, loc);
5601+
builder.restoreInsertionPoint(insertPt);
5602+
}

flang/lib/Lower/ConvertCall.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2682,10 +2682,48 @@ bool Fortran::lower::isIntrinsicModuleProcRef(
26822682
return module && module->attrs().test(Fortran::semantics::Attr::INTRINSIC);
26832683
}
26842684

2685+
static bool isInWhereMaskedExpression(fir::FirOpBuilder &builder) {
2686+
// The MASK of the outer WHERE is not masked itself.
2687+
mlir::Operation *op = builder.getRegion().getParentOp();
2688+
return op && op->getParentOfType<hlfir::WhereOp>();
2689+
}
2690+
26852691
std::optional<hlfir::EntityWithAttributes> Fortran::lower::convertCallToHLFIR(
26862692
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
26872693
const evaluate::ProcedureRef &procRef, std::optional<mlir::Type> resultType,
26882694
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
2695+
auto &builder = converter.getFirOpBuilder();
2696+
if (resultType && !procRef.IsElemental() &&
2697+
isInWhereMaskedExpression(builder) &&
2698+
!builder.getRegion().getParentOfType<hlfir::ExactlyOnceOp>()) {
2699+
// Non elemental calls inside a where-assignment-stmt must be executed
2700+
// exactly once without mask control. Lower them in a special region so that
2701+
// this can be enforced whenscheduling forall/where expression evaluations.
2702+
Fortran::lower::StatementContext localStmtCtx;
2703+
mlir::Type bogusType = builder.getIndexType();
2704+
auto exactlyOnce = builder.create<hlfir::ExactlyOnceOp>(loc, bogusType);
2705+
mlir::Block *block = builder.createBlock(&exactlyOnce.getBody());
2706+
builder.setInsertionPointToStart(block);
2707+
CallContext callContext(procRef, resultType, loc, converter, symMap,
2708+
localStmtCtx);
2709+
std::optional<hlfir::EntityWithAttributes> res =
2710+
genProcedureRef(callContext);
2711+
assert(res.has_value() && "must be a function");
2712+
auto yield = builder.create<hlfir::YieldOp>(loc, *res);
2713+
Fortran::lower::genCleanUpInRegionIfAny(loc, builder, yield.getCleanup(),
2714+
localStmtCtx);
2715+
builder.setInsertionPointAfter(exactlyOnce);
2716+
exactlyOnce->getResult(0).setType(res->getType());
2717+
if (hlfir::isFortranValue(exactlyOnce.getResult()))
2718+
return hlfir::EntityWithAttributes{exactlyOnce.getResult()};
2719+
// Create hlfir.declare for the result to satisfy
2720+
// hlfir::EntityWithAttributes requirements.
2721+
auto [exv, cleanup] = hlfir::translateToExtendedValue(
2722+
loc, builder, hlfir::Entity{exactlyOnce});
2723+
assert(!cleanup && "resut is a variable");
2724+
return hlfir::genDeclare(loc, builder, exv, ".func.pointer.result",
2725+
fir::FortranVariableFlagsAttr{});
2726+
}
26892727
CallContext callContext(procRef, resultType, loc, converter, symMap, stmtCtx);
26902728
return genProcedureRef(callContext);
26912729
}

0 commit comments

Comments
 (0)