1212#include " mlir/Interfaces/ValueBoundsOpInterface.h"
1313
1414using namespace mlir ;
15- using presburger::BoundType;
1615
1716namespace mlir {
1817namespace scf {
@@ -21,7 +20,28 @@ namespace {
2120struct ForOpInterface
2221 : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
2322
24- // / Populate bounds of values/dimensions for iter_args/OpResults.
23+ // / Populate bounds of values/dimensions for iter_args/OpResults. If the
24+ // / value/dimension size does not change in an iteration, we can deduce that
25+ // / it the same as the initial value/dimension.
26+ // /
27+ // / Example 1:
28+ // / %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29+ // / ...
30+ // / %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31+ // / scf.yield %1 : tensor<?xf32>
32+ // / }
33+ // / --> bound(%0)[0] == bound(%t)[0]
34+ // / --> bound(%arg0)[0] == bound(%t)[0]
35+ // /
36+ // / Example 2:
37+ // / %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38+ // / %sz = tensor.dim %arg0 : tensor<?xf32>
39+ // / %incr = arith.addi %sz, %c1 : index
40+ // / %1 = tensor.empty(%incr) : tensor<?xf32>
41+ // / scf.yield %1 : tensor<?xf32>
42+ // / }
43+ // / --> The yielded tensor dimension size changes with each iteration. Such
44+ // / loops are not supported and no constraints are added.
2545 static void populateIterArgBounds (scf::ForOp forOp, Value value,
2646 std::optional<int64_t > dim,
2747 ValueBoundsConstraintSet &cstr) {
@@ -33,59 +53,31 @@ struct ForOpInterface
3353 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber ();
3454 }
3555
36- // An EQ constraint can be added if the yielded value (dimension size)
37- // equals the corresponding block argument (dimension size).
3856 Value yieldedValue = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ())
3957 .getOperand (iterArgIdx);
4058 Value iterArg = forOp.getRegionIterArg (iterArgIdx);
4159 Value initArg = forOp.getInitArgs ()[iterArgIdx];
4260
43- auto addEqBound = [&]() {
61+ // Populate constraints for the yielded value.
62+ cstr.populateConstraints (yieldedValue, dim);
63+ // Populate constraints for the iter_arg. This is just to ensure that the
64+ // iter_arg is mapped in the constraint set, which is a prerequisite for
65+ // `compare`. It may lead to a recursive call to this function in case the
66+ // iter_arg was not visited when the constraints for the yielded value were
67+ // populated, but no additional work is done.
68+ cstr.populateConstraints (iterArg, dim);
69+
70+ // An EQ constraint can be added if the yielded value (dimension size)
71+ // equals the corresponding block argument (dimension size).
72+ if (cstr.compare (yieldedValue, dim,
73+ ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
74+ dim)) {
4475 if (dim.has_value ()) {
4576 cstr.bound (value)[*dim] == cstr.getExpr (initArg, dim);
4677 } else {
4778 cstr.bound (value) == initArg;
4879 }
49- };
50-
51- if (yieldedValue == iterArg) {
52- addEqBound ();
53- return ;
54- }
55-
56- // Compute EQ bound for yielded value.
57- AffineMap bound;
58- ValueDimList boundOperands;
59- LogicalResult status = ValueBoundsConstraintSet::computeBound (
60- bound, boundOperands, BoundType::EQ, yieldedValue, dim,
61- [&](Value v, std::optional<int64_t > d, ValueBoundsConstraintSet &cstr) {
62- // Stop when reaching a block argument of the loop body.
63- if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
64- return bbArg.getOwner ()->getParentOp () == forOp;
65- // Stop when reaching a value that is defined outside of the loop. It
66- // is impossible to reach an iter_arg from there.
67- Operation *op = v.getDefiningOp ();
68- return forOp.getRegion ().findAncestorOpInRegion (*op) == nullptr ;
69- });
70- if (failed (status))
71- return ;
72- if (bound.getNumResults () != 1 )
73- return ;
74-
75- // Check if computed bound equals the corresponding iter_arg.
76- Value singleValue = nullptr ;
77- std::optional<int64_t > singleDim;
78- if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult (0 ))) {
79- int64_t idx = dimExpr.getPosition ();
80- singleValue = boundOperands[idx].first ;
81- singleDim = boundOperands[idx].second ;
82- } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult (0 ))) {
83- int64_t idx = symExpr.getPosition () + bound.getNumDims ();
84- singleValue = boundOperands[idx].first ;
85- singleDim = boundOperands[idx].second ;
8680 }
87- if (singleValue == iterArg && singleDim == dim)
88- addEqBound ();
8981 }
9082
9183 void populateBoundsForIndexValue (Operation *op, Value value,
0 commit comments