Skip to content

Commit db3dde1

Browse files
[mlir][Interfaces][NFC] ValueBoundsConstraintSet: Pass stop condition in the constructor
This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed as a reference to the stop function, so that the stop function can be defined before the the `ValueBoundsConstraintSet` is constructed. This change is in preparation of adding support for branches.
1 parent 12e7e88 commit db3dde1

File tree

7 files changed

+62
-39
lines changed

7 files changed

+62
-39
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,9 @@ class ValueBoundsConstraintSet {
113113
///
114114
/// The first parameter of the function is the shaped value/index-typed
115115
/// value. The second parameter is the dimension in case of a shaped value.
116-
using StopConditionFn =
117-
function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
116+
/// The third parameter is this constraint set.
117+
using StopConditionFn = function_ref<bool(
118+
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
118119

119120
/// Compute a bound for the given index-typed value or shape dimension size.
120121
/// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -263,12 +264,12 @@ class ValueBoundsConstraintSet {
263264
/// An index-typed value or the dimension of a shaped-type value.
264265
using ValueDim = std::pair<Value, int64_t>;
265266

266-
ValueBoundsConstraintSet(MLIRContext *ctx);
267+
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
267268

268269
/// Iteratively process all elements on the worklist until an index-typed
269-
/// value or shaped value meets `stopCondition`. Such values are not processed
270-
/// any further.
271-
void processWorklist(StopConditionFn stopCondition);
270+
/// value or shaped value meets `currentStopCondition`. Such values are not
271+
/// processed any further.
272+
void processWorklist();
272273

273274
/// Bound the given column in the underlying constraint set by the given
274275
/// expression.
@@ -316,6 +317,9 @@ class ValueBoundsConstraintSet {
316317

317318
/// Builder for constructing affine expressions.
318319
Builder builder;
320+
321+
/// The current stop condition function.
322+
StopConditionFn stopCondition = nullptr;
319323
};
320324

321325
} // namespace mlir

mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
8484
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
8585
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
8686
bool closedUB) {
87-
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
87+
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
88+
ValueBoundsConstraintSet &cstr) {
8889
// We are trying to reify a bound for `value` in terms of the owning op's
8990
// operands. Construct a stop condition that evaluates to "true" for any SSA
9091
// value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
100101
FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
101102
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
102103
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
103-
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
104+
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
105+
ValueBoundsConstraintSet &cstr) {
104106
return v != value;
105107
};
106108
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,

mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
119119
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
120120
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
121121
bool closedUB) {
122-
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
122+
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
123+
ValueBoundsConstraintSet &cstr) {
123124
// We are trying to reify a bound for `value` in terms of the owning op's
124125
// operands. Construct a stop condition that evaluates to "true" for any SSA
125126
// value expect for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
135136
FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
136137
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
137138
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
138-
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
139+
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
140+
ValueBoundsConstraintSet &cstr) {
139141
return v != value;
140142
};
141143
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
468468
FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
469469
rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
470470
/*stopCondition=*/
471-
[&](Value v, std::optional<int64_t> d) {
471+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
472472
if (v == forOp.getUpperBound())
473473
return false;
474474
// Compute a bound that is independent of any affine op results.

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct ForOpInterface
5858
ValueDimList boundOperands;
5959
LogicalResult status = ValueBoundsConstraintSet::computeBound(
6060
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
61-
[&](Value v, std::optional<int64_t> d) {
61+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
6262
// Stop when reaching a block argument of the loop body.
6363
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
6464
return bbArg.getOwner()->getParentOp() == forOp;

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
6767
return std::nullopt;
6868
}
6969

70-
ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
71-
: builder(ctx) {}
70+
ValueBoundsConstraintSet::ValueBoundsConstraintSet(
71+
MLIRContext *ctx, StopConditionFn stopCondition)
72+
: builder(ctx), stopCondition(stopCondition) {}
7273

7374
#ifndef NDEBUG
7475
static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
@@ -228,7 +229,8 @@ static Operation *getOwnerOfValue(Value value) {
228229
return value.getDefiningOp();
229230
}
230231

231-
void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
232+
void ValueBoundsConstraintSet::processWorklist() {
233+
LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
232234
while (!worklist.empty()) {
233235
int64_t pos = worklist.front();
234236
worklist.pop();
@@ -249,13 +251,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
249251

250252
// Do not process any further if the stop condition is met.
251253
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
252-
if (stopCondition(value, maybeDim))
254+
if (stopCondition(value, maybeDim, *this)) {
255+
LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
256+
<< " (dim: " << maybeDim << ")\n");
253257
continue;
258+
}
254259

255260
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
256261
// the worklist.
257262
auto valueBoundsOp =
258263
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
264+
LLVM_DEBUG(llvm::dbgs()
265+
<< "Query value bounds for: " << value
266+
<< " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
259267
if (valueBoundsOp) {
260268
if (dim == kIndexValue) {
261269
valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -264,6 +272,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
264272
}
265273
continue;
266274
}
275+
LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
267276

268277
// If the op does not implement `ValueBoundsOpInterface`, check if it
269278
// implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -313,8 +322,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
313322
bool closedUB) {
314323
#ifndef NDEBUG
315324
assertValidValueDim(value, dim);
316-
assert(!stopCondition(value, dim) &&
317-
"stop condition should not be satisfied for starting point");
318325
#endif // NDEBUG
319326

320327
int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -324,9 +331,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
324331
// Process the backward slice of `value` (i.e., reverse use-def chain) until
325332
// `stopCondition` is met.
326333
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
327-
ValueBoundsConstraintSet cstr(value.getContext());
334+
ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
335+
assert(!stopCondition(value, dim, cstr) &&
336+
"stop condition should not be satisfied for starting point");
328337
int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
329-
cstr.processWorklist(stopCondition);
338+
cstr.processWorklist();
330339

331340
// Project out all variables (apart from `valueDim`) that do not match the
332341
// stop condition.
@@ -336,7 +345,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
336345
return false;
337346
auto maybeDim =
338347
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
339-
return !stopCondition(p.first, maybeDim);
348+
return !stopCondition(p.first, maybeDim, cstr);
340349
});
341350

342351
// Compute lower and upper bounds for `valueDim`.
@@ -442,7 +451,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
442451
bool closedUB) {
443452
return computeBound(
444453
resultMap, mapOperands, type, value, dim,
445-
[&](Value v, std::optional<int64_t> d) {
454+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
446455
return llvm::is_contained(dependencies, std::make_pair(v, d));
447456
},
448457
closedUB);
@@ -478,7 +487,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
478487
// Reify bounds in terms of any independent values.
479488
return computeBound(
480489
resultMap, mapOperands, type, value, dim,
481-
[&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
490+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
491+
return isIndependent(v);
492+
},
482493
closedUB);
483494
}
484495

@@ -500,8 +511,18 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
500511
presburger::BoundType type, AffineMap map, ValueDimList operands,
501512
StopConditionFn stopCondition, bool closedUB) {
502513
assert(map.getNumResults() == 1 && "expected affine map with one result");
503-
ValueBoundsConstraintSet cstr(map.getContext());
504-
int64_t pos = cstr.insert(/*isSymbol=*/false);
514+
515+
// Default stop condition if none was specified: Keep adding constraints until
516+
// a bound could be computed.
517+
int64_t pos;
518+
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
519+
ValueBoundsConstraintSet &cstr) {
520+
return cstr.cstr.getConstantBound64(type, pos).has_value();
521+
};
522+
523+
ValueBoundsConstraintSet cstr(
524+
map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
525+
pos = cstr.insert(/*isSymbol=*/false);
505526

506527
// Add map and operands to the constraint set. Dimensions are converted to
507528
// symbols. All operands are added to the worklist.
@@ -517,17 +538,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
517538
map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
518539

519540
// Process the backward slice of `operands` (i.e., reverse use-def chain)
520-
// until `stopCondition` is met.
521-
if (stopCondition) {
522-
cstr.processWorklist(stopCondition);
523-
} else {
524-
// No stop condition specified: Keep adding constraints until a bound could
525-
// be computed.
526-
cstr.processWorklist(
527-
/*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
528-
return cstr.cstr.getConstantBound64(type, pos).has_value();
529-
});
530-
}
541+
// until the stop condition is met.
542+
cstr.processWorklist();
531543

532544
// Compute constant bound for `valueDim`.
533545
int64_t ubAdjustment = closedUB ? 0 : 1;

mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
112112

113113
// Prepare stop condition. By default, reify in terms of the op's
114114
// operands. No stop condition is used when a constant was requested.
115-
std::function<bool(Value, std::optional<int64_t>)> stopCondition =
116-
[&](Value v, std::optional<int64_t> d) {
115+
std::function<bool(Value, std::optional<int64_t>,
116+
ValueBoundsConstraintSet & cstr)>
117+
stopCondition = [&](Value v, std::optional<int64_t> d,
118+
ValueBoundsConstraintSet &cstr) {
117119
// Reify in terms of SSA values that are different from `value`.
118120
return v != value;
119121
};
120122
if (reifyToFuncArgs) {
121123
// Reify in terms of function block arguments.
122-
stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
124+
stopCondition = stopCondition = [](Value v, std::optional<int64_t> d,
125+
ValueBoundsConstraintSet &cstr) {
123126
auto bbArg = dyn_cast<BlockArgument>(v);
124127
if (!bbArg)
125128
return false;

0 commit comments

Comments
 (0)