12
12
#include " mlir/Interfaces/ValueBoundsOpInterface.h"
13
13
14
14
using namespace mlir ;
15
- using presburger::BoundType;
16
15
17
16
namespace mlir {
18
17
namespace scf {
@@ -21,7 +20,28 @@ namespace {
21
20
struct ForOpInterface
22
21
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
23
22
24
- // / Populate bounds of values/dimensions for iter_args/OpResults.
23
+ // / Populate bounds of values/dimensions for iter_args/OpResults. If the
24
+ // / value/dimension size does not change in an iteration, we can deduce that
25
+ // / it the same as the initial value/dimension.
26
+ // /
27
+ // / Example 1:
28
+ // / %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29
+ // / ...
30
+ // / %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31
+ // / scf.yield %1 : tensor<?xf32>
32
+ // / }
33
+ // / --> bound(%0)[0] == bound(%t)[0]
34
+ // / --> bound(%arg0)[0] == bound(%t)[0]
35
+ // /
36
+ // / Example 2:
37
+ // / %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38
+ // / %sz = tensor.dim %arg0 : tensor<?xf32>
39
+ // / %incr = arith.addi %sz, %c1 : index
40
+ // / %1 = tensor.empty(%incr) : tensor<?xf32>
41
+ // / scf.yield %1 : tensor<?xf32>
42
+ // / }
43
+ // / --> The yielded tensor dimension size changes with each iteration. Such
44
+ // / loops are not supported and no constraints are added.
25
45
static void populateIterArgBounds (scf::ForOp forOp, Value value,
26
46
std::optional<int64_t > dim,
27
47
ValueBoundsConstraintSet &cstr) {
@@ -33,59 +53,31 @@ struct ForOpInterface
33
53
iterArgIdx = llvm::cast<OpResult>(value).getResultNumber ();
34
54
}
35
55
36
- // An EQ constraint can be added if the yielded value (dimension size)
37
- // equals the corresponding block argument (dimension size).
38
56
Value yieldedValue = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ())
39
57
.getOperand (iterArgIdx);
40
58
Value iterArg = forOp.getRegionIterArg (iterArgIdx);
41
59
Value initArg = forOp.getInitArgs ()[iterArgIdx];
42
60
43
- auto addEqBound = [&]() {
61
+ // Populate constraints for the yielded value.
62
+ cstr.populateConstraints (yieldedValue, dim);
63
+ // Populate constraints for the iter_arg. This is just to ensure that the
64
+ // iter_arg is mapped in the constraint set, which is a prerequisite for
65
+ // `compare`. It may lead to a recursive call to this function in case the
66
+ // iter_arg was not visited when the constraints for the yielded value were
67
+ // populated, but no additional work is done.
68
+ cstr.populateConstraints (iterArg, dim);
69
+
70
+ // An EQ constraint can be added if the yielded value (dimension size)
71
+ // equals the corresponding block argument (dimension size).
72
+ if (cstr.compare (yieldedValue, dim,
73
+ ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
74
+ dim)) {
44
75
if (dim.has_value ()) {
45
76
cstr.bound (value)[*dim] == cstr.getExpr (initArg, dim);
46
77
} else {
47
78
cstr.bound (value) == initArg;
48
79
}
49
- };
50
-
51
- if (yieldedValue == iterArg) {
52
- addEqBound ();
53
- return ;
54
- }
55
-
56
- // Compute EQ bound for yielded value.
57
- AffineMap bound;
58
- ValueDimList boundOperands;
59
- LogicalResult status = ValueBoundsConstraintSet::computeBound (
60
- bound, boundOperands, BoundType::EQ, yieldedValue, dim,
61
- [&](Value v, std::optional<int64_t > d, ValueBoundsConstraintSet &cstr) {
62
- // Stop when reaching a block argument of the loop body.
63
- if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
64
- return bbArg.getOwner ()->getParentOp () == forOp;
65
- // Stop when reaching a value that is defined outside of the loop. It
66
- // is impossible to reach an iter_arg from there.
67
- Operation *op = v.getDefiningOp ();
68
- return forOp.getRegion ().findAncestorOpInRegion (*op) == nullptr ;
69
- });
70
- if (failed (status))
71
- return ;
72
- if (bound.getNumResults () != 1 )
73
- return ;
74
-
75
- // Check if computed bound equals the corresponding iter_arg.
76
- Value singleValue = nullptr ;
77
- std::optional<int64_t > singleDim;
78
- if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult (0 ))) {
79
- int64_t idx = dimExpr.getPosition ();
80
- singleValue = boundOperands[idx].first ;
81
- singleDim = boundOperands[idx].second ;
82
- } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult (0 ))) {
83
- int64_t idx = symExpr.getPosition () + bound.getNumDims ();
84
- singleValue = boundOperands[idx].first ;
85
- singleDim = boundOperands[idx].second ;
86
80
}
87
- if (singleValue == iterArg && singleDim == dim)
88
- addEqBound ();
89
81
}
90
82
91
83
void populateBoundsForIndexValue (Operation *op, Value value,
0 commit comments