Skip to content

Commit 0c94128

Browse files
[mlir][SCF] ValueBoundsConstraintSet: Support preliminary support for branches
This commit adds support for `scf.if` to `ValueBoundsConstraintSet`. Example: ``` %0 = scf.if ... -> index { scf.yield %a : index } else { scf.yield %b : index } ``` The following constraints hold for %0: * %0 >= min(%a, %b) * %0 <= max(%a, %b) Such constraints cannot be added to the constraint set; min/max is not supported by `IntegerRelation`. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b: * %0 >= %a * %0 <= %b This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.
1 parent a4ca07f commit 0c94128

File tree

4 files changed

+329
-29
lines changed

4 files changed

+329
-29
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,28 @@ class ValueBoundsConstraintSet {
198198
std::optional<int64_t> dim1 = std::nullopt,
199199
std::optional<int64_t> dim2 = std::nullopt);
200200

201+
/// Traverse the IR starting from the given value/dim and add populate
202+
/// constraints as long as the currently set stop condition holds. Also
203+
/// processes all values/dims that are already on the worklist.
204+
void populateConstraints(Value value, std::optional<int64_t> dim);
205+
206+
/// Comparison operator for `ValueBoundsConstraintSet::compare`.
207+
enum ComparisonOperator { LT, LE, EQ, GT, GE };
208+
209+
/// Try to prove that, based on the current state of this constraint set
210+
/// (i.e., without analyzing additional IR or adding new constraints), it can
211+
/// be deduced that the first given value/dim is LE/LT/EQ/GT/GE than the
212+
/// second given value/dim.
213+
///
214+
/// Return "true" if the specified relation between the two values/dims was
215+
/// proven to hold. Return "false" if the specified relation could not be
216+
/// proven. This could be because the specified relation does in fact not hold
217+
/// or because there is not enough information in the constraint set. In other
218+
/// words, if we do not know for sure, this function returns "false".
219+
bool compare(Value value1, std::optional<int64_t> dim1,
220+
ComparisonOperator cmp, Value value2,
221+
std::optional<int64_t> dim2);
222+
201223
/// Compute whether the given values/dimensions are equal. Return "failure" if
202224
/// equality could not be determined.
203225
///
@@ -266,9 +288,9 @@ class ValueBoundsConstraintSet {
266288
ValueBoundsConstraintSet(MLIRContext *ctx);
267289

268290
/// Iteratively process all elements on the worklist until an index-typed
269-
/// value or shaped value meets `stopCondition`. Such values are not processed
270-
/// any further.
271-
void processWorklist(StopConditionFn stopCondition);
291+
/// value or shaped value meets `currentStopCondition`. Such values are not
292+
/// processed any further.
293+
void processWorklist();
272294

273295
/// Bound the given column in the underlying constraint set by the given
274296
/// expression.
@@ -280,12 +302,13 @@ class ValueBoundsConstraintSet {
280302

281303
/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
282304
/// "false", a dimension is added. The value/dimension is added to the
283-
/// worklist.
305+
/// worklist if `addToWorklist` is set.
284306
///
285307
/// Note: There are certain affine restrictions wrt. dimensions. E.g., they
286308
/// cannot be multiplied. Furthermore, bounds can only be queried for
287309
/// dimensions but not for symbols.
288-
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
310+
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
311+
bool addToWorklist = true);
289312

290313
/// Insert an anonymous column into the constraint set. The column is not
291314
/// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
@@ -315,6 +338,9 @@ class ValueBoundsConstraintSet {
315338

316339
/// Builder for constructing affine expressions.
317340
Builder builder;
341+
342+
/// The current stop condition function.
343+
StopConditionFn currentStopCondition = nullptr;
318344
};
319345

320346
} // namespace mlir

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,68 @@ struct ForOpInterface
111111
}
112112
};
113113

114+
struct IfOpInterface
115+
: public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
116+
117+
void populateBoundsForIndexValue(Operation *op, Value value,
118+
ValueBoundsConstraintSet &cstr) const {
119+
auto ifOp = cast<IfOp>(op);
120+
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
121+
Value thenValue = ifOp.thenYield().getResults()[resultNum];
122+
Value elseValue = ifOp.elseYield().getResults()[resultNum];
123+
124+
// Populate constraints for the yielded value (and all values on the
125+
// backward slice, as long as the current stop condition is not satisfied).
126+
cstr.populateConstraints(thenValue, /*valueDim=*/std::nullopt);
127+
cstr.populateConstraints(elseValue, /*valueDim=*/std::nullopt);
128+
129+
// Compare yielded values.
130+
// If thenValue <= elseValue:
131+
// * result <= elseValue
132+
// * result >= thenValue
133+
if (cstr.compare(thenValue, /*dim1=*/std::nullopt,
134+
ValueBoundsConstraintSet::ComparisonOperator::LE,
135+
elseValue, /*dim2=*/std::nullopt)) {
136+
cstr.bound(value) >= thenValue;
137+
cstr.bound(value) <= elseValue;
138+
}
139+
// If elseValue <= thenValue:
140+
// * result <= thenValue
141+
// * result >= elseValue
142+
if (cstr.compare(elseValue, /*dim1=*/std::nullopt,
143+
ValueBoundsConstraintSet::ComparisonOperator::LE,
144+
thenValue, /*dim2=*/std::nullopt)) {
145+
cstr.bound(value) >= elseValue;
146+
cstr.bound(value) <= thenValue;
147+
}
148+
}
149+
150+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
151+
ValueBoundsConstraintSet &cstr) const {
152+
// See `populateBoundsForIndexValue` for documentation.
153+
auto ifOp = cast<IfOp>(op);
154+
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
155+
Value thenValue = ifOp.thenYield().getResults()[resultNum];
156+
Value elseValue = ifOp.elseYield().getResults()[resultNum];
157+
158+
cstr.populateConstraints(thenValue, dim);
159+
cstr.populateConstraints(elseValue, dim);
160+
161+
if (cstr.compare(thenValue, dim,
162+
ValueBoundsConstraintSet::ComparisonOperator::LE,
163+
elseValue, dim)) {
164+
cstr.bound(value)[dim] >= cstr.getExpr(thenValue, dim);
165+
cstr.bound(value)[dim] <= cstr.getExpr(elseValue, dim);
166+
}
167+
if (cstr.compare(elseValue, dim,
168+
ValueBoundsConstraintSet::ComparisonOperator::LE,
169+
thenValue, dim)) {
170+
cstr.bound(value)[dim] >= cstr.getExpr(elseValue, dim);
171+
cstr.bound(value)[dim] <= cstr.getExpr(thenValue, dim);
172+
}
173+
}
174+
};
175+
114176
} // namespace
115177
} // namespace scf
116178
} // namespace mlir
@@ -119,5 +181,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
119181
DialectRegistry &registry) {
120182
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
121183
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
184+
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
122185
});
123186
}

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 118 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -105,25 +105,41 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
105105
assertValidValueDim(value, dim);
106106
#endif // NDEBUG
107107

108+
auto getPosExpr = [&](int64_t pos) {
109+
assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() &&
110+
"invalid position");
111+
return pos < cstr.getNumDimVars()
112+
? builder.getAffineDimExpr(pos)
113+
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
114+
};
115+
116+
// If the value/dim is already mapped, return the corresponding expression
117+
// directly.
118+
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
119+
if (valueDimToPosition.contains(valueDim))
120+
return getPosExpr(getPos(value, dim));
121+
108122
auto shapedType = dyn_cast<ShapedType>(value.getType());
109123
if (shapedType) {
110-
// Static dimension: return constant directly.
111-
if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
124+
// Static dimension: add EQ bound and return expression without pushing the
125+
// dim onto the worklist.
126+
if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) {
127+
(void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
128+
bound(value)[*dim] == shapedType.getDimSize(*dim);
112129
return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
130+
}
113131
} else {
114-
// Constant index value: return directly.
115-
if (auto constInt = ::getConstantIntValue(value))
132+
// Constant index value: add EQ bound and return expression without pushing
133+
// the value onto the worklist.
134+
if (auto constInt = ::getConstantIntValue(value)) {
135+
(void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
136+
bound(value) == *constInt;
116137
return builder.getAffineConstantExpr(*constInt);
138+
}
117139
}
118140

119-
// Dynamic value: add to constraint set.
120-
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
121-
if (!valueDimToPosition.contains(valueDim))
122-
(void)insert(value, dim);
123-
int64_t pos = getPos(value, dim);
124-
return pos < cstr.getNumDimVars()
125-
? builder.getAffineDimExpr(pos)
126-
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
141+
// Dynamic value/dim: add to worklist.
142+
return getPosExpr(insert(value, dim, /*isSymbol=*/true));
127143
}
128144

129145
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
@@ -140,7 +156,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
140156

141157
int64_t ValueBoundsConstraintSet::insert(Value value,
142158
std::optional<int64_t> dim,
143-
bool isSymbol) {
159+
bool isSymbol, bool addToWorklist) {
144160
#ifndef NDEBUG
145161
assertValidValueDim(value, dim);
146162
#endif // NDEBUG
@@ -155,7 +171,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
155171
if (positionToValueDim[i].has_value())
156172
valueDimToPosition[*positionToValueDim[i]] = i;
157173

158-
worklist.push(pos);
174+
if (addToWorklist) {
175+
LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
176+
<< " (dim: " << dim.value_or(kIndexValue) << ")\n");
177+
worklist.push(pos);
178+
}
179+
159180
return pos;
160181
}
161182

@@ -191,7 +212,8 @@ static Operation *getOwnerOfValue(Value value) {
191212
return value.getDefiningOp();
192213
}
193214

194-
void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
215+
void ValueBoundsConstraintSet::processWorklist() {
216+
LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
195217
while (!worklist.empty()) {
196218
int64_t pos = worklist.front();
197219
worklist.pop();
@@ -212,20 +234,29 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
212234

213235
// Do not process any further if the stop condition is met.
214236
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
215-
if (stopCondition(value, maybeDim))
237+
if (currentStopCondition(value, maybeDim)) {
238+
LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
239+
<< " (dim: " << maybeDim << ")\n");
216240
continue;
241+
}
217242

218243
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
219244
// the worklist.
220245
auto valueBoundsOp =
221246
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
247+
LLVM_DEBUG(llvm::dbgs()
248+
<< "Query value bounds for: " << value
249+
<< " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
222250
if (valueBoundsOp) {
223251
if (dim == kIndexValue) {
224252
valueBoundsOp.populateBoundsForIndexValue(value, *this);
225253
} else {
226254
valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
227255
}
228256
continue;
257+
} else {
258+
LLVM_DEBUG(llvm::dbgs()
259+
<< "--> ValueBoundsOpInterface not implemented\n");
229260
}
230261

231262
// If the op does not implement `ValueBoundsOpInterface`, check if it
@@ -301,7 +332,8 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
301332
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
302333
ValueBoundsConstraintSet cstr(value.getContext());
303334
int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
304-
cstr.processWorklist(stopCondition);
335+
cstr.currentStopCondition = stopCondition;
336+
cstr.processWorklist();
305337

306338
// Project out all variables (apart from `valueDim`) that do not match the
307339
// stop condition.
@@ -494,14 +526,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
494526
// Process the backward slice of `operands` (i.e., reverse use-def chain)
495527
// until `stopCondition` is met.
496528
if (stopCondition) {
497-
cstr.processWorklist(stopCondition);
529+
cstr.currentStopCondition = stopCondition;
530+
cstr.processWorklist();
498531
} else {
499532
// No stop condition specified: Keep adding constraints until a bound could
500533
// be computed.
501-
cstr.processWorklist(
502-
/*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
503-
return cstr.cstr.getConstantBound64(type, pos).has_value();
504-
});
534+
auto stopCondFn = [&](Value v, std::optional<int64_t> dim) {
535+
return cstr.cstr.getConstantBound64(type, pos).has_value();
536+
};
537+
cstr.currentStopCondition = stopCondFn;
538+
cstr.processWorklist();
505539
}
506540

507541
// Compute constant bound for `valueDim`.
@@ -538,6 +572,68 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
538572
{{value1, dim1}, {value2, dim2}});
539573
}
540574

575+
void ValueBoundsConstraintSet::populateConstraints(Value value,
576+
std::optional<int64_t> dim) {
577+
// `getExpr` pushes the value/dim onto the worklist (unless it was already
578+
// analyzed).
579+
(void)getExpr(value, dim);
580+
// Process all values/dims on the worklist. This may traverse and analyze
581+
// additional IR, depending the current stop function.
582+
processWorklist();
583+
}
584+
585+
bool ValueBoundsConstraintSet::compare(Value value1,
586+
std::optional<int64_t> dim1,
587+
ComparisonOperator cmp, Value value2,
588+
std::optional<int64_t> dim2) {
589+
// This function returns "true" if value1/dim1 CMP value2/dim2 is proved to
590+
// hold.
591+
//
592+
// Example for ComparisonOperator::LE and index-typed values: We would like to
593+
// prove that value1 <= value2. Proof by contradiction: add the inverse
594+
// relation (value1 > value2) to the constraint set and check if the resulting
595+
// constraint set is "empty" (i.e. has no solution). In that case,
596+
// value1 > value2 must be incorrect and we can deduce that value1 <= value2
597+
// holds.
598+
599+
// We cannot use prove anything if the constraint set is already empty.
600+
if (cstr.isEmpty()) {
601+
LLVM_DEBUG(
602+
llvm::dbgs()
603+
<< "cannot compare value/dims: constraint system is already empty");
604+
return false;
605+
}
606+
607+
// EQ can be expressed as LE and GE.
608+
if (cmp == EQ)
609+
return compare(value1, dim1, ComparisonOperator::LE, value2, dim2) &&
610+
compare(value1, dim1, ComparisonOperator::GE, value2, dim2);
611+
612+
// Construct inequality. For the above example: value1 > value2.
613+
// `IntegerRelation` inequalities are expressed in the "flattened" form and
614+
// with ">= 0". I.e., value1 - value2 - 1 >= 0.
615+
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
616+
if (cmp == LT || cmp == LE) {
617+
eq[getPos(value1, dim1)]++;
618+
eq[getPos(value2, dim2)]--;
619+
} else if (cmp == GT || cmp == GE) {
620+
eq[getPos(value1, dim1)]--;
621+
eq[getPos(value2, dim2)]++;
622+
} else {
623+
llvm_unreachable("unsupported comparison operator");
624+
}
625+
if (cmp == LE || cmp == GE)
626+
eq[cstr.getNumDimAndSymbolVars()] -= 1;
627+
628+
// Add inequality to the constraint set and check if it made the constraint
629+
// set empty.
630+
int64_t ineqPos = cstr.getNumInequalities();
631+
cstr.addInequality(eq);
632+
bool isEmpty = cstr.isEmpty();
633+
cstr.removeInequality(ineqPos);
634+
return isEmpty;
635+
}
636+
541637
FailureOr<bool>
542638
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
543639
std::optional<int64_t> dim1,

0 commit comments

Comments
 (0)