Skip to content

Commit 0ba3e96

Browse files
[mlir][SCF][NFC] ValueBoundsConstraintSet: Simplify scf.for implementation (#87862)
This commit simplifies the implementation of the `ValueBoundsOpInterface` for `scf.for` based on the newly added `ValueBoundsConstraintSet::compare` API and adds additional documentation. Previously, the interface implementation created a new constraint set just to check if the yielded value and iter_arg are equal. This was inefficient because constraints were added multiple times (to two different constraint sets) for ops that are inside the loop. Note: This is a re-upload of #86239.
1 parent 7702023 commit 0ba3e96

File tree

1 file changed

+36
-44
lines changed

1 file changed

+36
-44
lines changed

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 36 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1313

1414
using namespace mlir;
15-
using presburger::BoundType;
1615

1716
namespace mlir {
1817
namespace scf {
@@ -21,7 +20,28 @@ namespace {
2120
struct ForOpInterface
2221
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
2322

24-
/// Populate bounds of values/dimensions for iter_args/OpResults.
23+
/// Populate bounds of values/dimensions for iter_args/OpResults. If the
24+
/// value/dimension size does not change in an iteration, we can deduce that
25+
/// it is the same as the initial value/dimension.
26+
///
27+
/// Example 1:
28+
/// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29+
/// ...
30+
/// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31+
/// scf.yield %1 : tensor<?xf32>
32+
/// }
33+
/// --> bound(%0)[0] == bound(%t)[0]
34+
/// --> bound(%arg0)[0] == bound(%t)[0]
35+
///
36+
/// Example 2:
37+
/// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38+
/// %sz = tensor.dim %arg0 : tensor<?xf32>
39+
/// %incr = arith.addi %sz, %c1 : index
40+
/// %1 = tensor.empty(%incr) : tensor<?xf32>
41+
/// scf.yield %1 : tensor<?xf32>
42+
/// }
43+
/// --> The yielded tensor dimension size changes with each iteration. Such
44+
/// loops are not supported and no constraints are added.
2545
static void populateIterArgBounds(scf::ForOp forOp, Value value,
2646
std::optional<int64_t> dim,
2747
ValueBoundsConstraintSet &cstr) {
@@ -33,59 +53,31 @@ struct ForOpInterface
3353
iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
3454
}
3555

36-
// An EQ constraint can be added if the yielded value (dimension size)
37-
// equals the corresponding block argument (dimension size).
3856
Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
3957
.getOperand(iterArgIdx);
4058
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
4159
Value initArg = forOp.getInitArgs()[iterArgIdx];
4260

43-
auto addEqBound = [&]() {
61+
// Populate constraints for the yielded value.
62+
cstr.populateConstraints(yieldedValue, dim);
63+
// Populate constraints for the iter_arg. This is just to ensure that the
64+
// iter_arg is mapped in the constraint set, which is a prerequisite for
65+
// `compare`. It may lead to a recursive call to this function in case the
66+
// iter_arg was not visited when the constraints for the yielded value were
67+
// populated, but no additional work is done.
68+
cstr.populateConstraints(iterArg, dim);
69+
70+
// An EQ constraint can be added if the yielded value (dimension size)
71+
// equals the corresponding block argument (dimension size).
72+
if (cstr.compare(yieldedValue, dim,
73+
ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
74+
dim)) {
4475
if (dim.has_value()) {
4576
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
4677
} else {
4778
cstr.bound(value) == initArg;
4879
}
49-
};
50-
51-
if (yieldedValue == iterArg) {
52-
addEqBound();
53-
return;
54-
}
55-
56-
// Compute EQ bound for yielded value.
57-
AffineMap bound;
58-
ValueDimList boundOperands;
59-
LogicalResult status = ValueBoundsConstraintSet::computeBound(
60-
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
61-
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
62-
// Stop when reaching a block argument of the loop body.
63-
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
64-
return bbArg.getOwner()->getParentOp() == forOp;
65-
// Stop when reaching a value that is defined outside of the loop. It
66-
// is impossible to reach an iter_arg from there.
67-
Operation *op = v.getDefiningOp();
68-
return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
69-
});
70-
if (failed(status))
71-
return;
72-
if (bound.getNumResults() != 1)
73-
return;
74-
75-
// Check if computed bound equals the corresponding iter_arg.
76-
Value singleValue = nullptr;
77-
std::optional<int64_t> singleDim;
78-
if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult(0))) {
79-
int64_t idx = dimExpr.getPosition();
80-
singleValue = boundOperands[idx].first;
81-
singleDim = boundOperands[idx].second;
82-
} else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult(0))) {
83-
int64_t idx = symExpr.getPosition() + bound.getNumDims();
84-
singleValue = boundOperands[idx].first;
85-
singleDim = boundOperands[idx].second;
8680
}
87-
if (singleValue == iterArg && singleDim == dim)
88-
addEqBound();
8981
}
9082

9183
void populateBoundsForIndexValue(Operation *op, Value value,

0 commit comments

Comments
 (0)