Skip to content

Commit cec5571

Browse files
[mlir][SCF] ValueBoundsConstraintSet: Support preliminary support for branches
This commit adds support for `scf.if` to `ValueBoundsConstraintSet`. Example: ``` %0 = scf.if ... -> index { scf.yield %a : index } else { scf.yield %b : index } ``` The following constraints hold for %0: * %0 >= min(%a, %b) * %0 <= max(%a, %b) Such constraints cannot be added to the constraint set; min/max is not supported by `IntegerRelation`. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b: * %0 >= %a * %0 <= %b This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.
1 parent 843cc47 commit cec5571

File tree

5 files changed

+334
-46
lines changed

5 files changed

+334
-46
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,26 @@ class ValueBoundsConstraintSet
203203
std::optional<int64_t> dim1 = std::nullopt,
204204
std::optional<int64_t> dim2 = std::nullopt);
205205

206+
/// Traverse the IR starting from the given value/dim and populate constraints
207+
/// as long as the stop condition holds. Also process all values/dims that are
208+
/// already on the worklist.
209+
void populateConstraints(Value value, std::optional<int64_t> dim);
210+
211+
/// Comparison operator for `ValueBoundsConstraintSet::compare`.
212+
enum ComparisonOperator { LT, LE, EQ, GT, GE };
213+
214+
/// Try to prove that, based on the current state of this constraint set
215+
/// (i.e., without analyzing additional IR or adding new constraints), the
216+
/// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
217+
///
218+
/// Return "true" if the specified relation between the two values/dims was
219+
/// proven to hold. Return "false" if the specified relation could not be
220+
/// proven. This could be because the specified relation does in fact not hold
221+
/// or because there is not enough information in the constraint set. In other
222+
/// words, if we do not know for sure, this function returns "false".
223+
bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
224+
Value rhs, std::optional<int64_t> rhsDim);
225+
206226
/// Compute whether the given values/dimensions are equal. Return "failure" if
207227
/// equality could not be determined.
208228
///
@@ -274,13 +294,13 @@ class ValueBoundsConstraintSet
274294

275295
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
276296

277-
/// Populates the constraint set for a value/map without actually computing
278-
/// the bound. Returns the position for the value/map (via the return value
279-
/// and `posOut` output parameter).
280-
int64_t populateConstraintsSet(Value value,
281-
std::optional<int64_t> dim = std::nullopt);
282-
int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
283-
int64_t *posOut = nullptr);
297+
/// Given an affine map with a single result (and map operands), add a new
298+
/// column to the constraint set that represents the result of the map.
299+
/// Traverse additional IR starting from the map operands as needed (as long
300+
/// as the stop condition is not satisfied). Also process all values/dims that
301+
/// are already on the worklist. Return the position of the newly added
302+
/// column.
303+
int64_t populateConstraints(AffineMap map, ValueDimList mapOperands);
284304

285305
/// Iteratively process all elements on the worklist until an index-typed
286306
/// value or shaped value meets `stopCondition`. Such values are not processed
@@ -295,14 +315,19 @@ class ValueBoundsConstraintSet
295315
/// value/dimension exists in the constraint set.
296316
int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
297317

318+
/// Return an affine expression that represents column `pos` in the constraint
319+
/// set.
320+
AffineExpr getPosExpr(int64_t pos);
321+
298322
/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
299323
/// "false", a dimension is added. The value/dimension is added to the
300-
/// worklist.
324+
/// worklist if `addToWorklist` is set.
301325
///
302326
/// Note: There are certain affine restrictions wrt. dimensions. E.g., they
303327
/// cannot be multiplied. Furthermore, bounds can only be queried for
304328
/// dimensions but not for symbols.
305-
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
329+
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
330+
bool addToWorklist = true);
306331

307332
/// Insert an anonymous column into the constraint set. The column is not
308333
/// bound to any value/dimension. If `isSymbol` is set to "false", a dimension

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,66 @@ struct ForOpInterface
111111
}
112112
};
113113

114+
struct IfOpInterface
115+
: public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
116+
117+
static void populateBounds(scf::IfOp ifOp, Value value,
118+
std::optional<int64_t> dim,
119+
ValueBoundsConstraintSet &cstr) {
120+
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
121+
Value thenValue = ifOp.thenYield().getResults()[resultNum];
122+
Value elseValue = ifOp.elseYield().getResults()[resultNum];
123+
124+
// Populate constraints for the yielded value (and all values on the
125+
// backward slice, as long as the current stop condition is not satisfied).
126+
cstr.populateConstraints(thenValue, dim);
127+
cstr.populateConstraints(elseValue, dim);
128+
auto boundsBuilder = cstr.bound(value);
129+
if (dim)
130+
boundsBuilder[*dim];
131+
132+
// Compare yielded values.
133+
// If thenValue <= elseValue:
134+
// * result <= elseValue
135+
// * result >= thenValue
136+
if (cstr.compare(thenValue, dim,
137+
ValueBoundsConstraintSet::ComparisonOperator::LE,
138+
elseValue, dim)) {
139+
if (dim) {
140+
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
141+
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
142+
} else {
143+
cstr.bound(value) >= thenValue;
144+
cstr.bound(value) <= elseValue;
145+
}
146+
}
147+
// If elseValue <= thenValue:
148+
// * result <= thenValue
149+
// * result >= elseValue
150+
if (cstr.compare(elseValue, dim,
151+
ValueBoundsConstraintSet::ComparisonOperator::LE,
152+
thenValue, dim)) {
153+
if (dim) {
154+
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
155+
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
156+
} else {
157+
cstr.bound(value) >= elseValue;
158+
cstr.bound(value) <= thenValue;
159+
}
160+
}
161+
}
162+
163+
void populateBoundsForIndexValue(Operation *op, Value value,
164+
ValueBoundsConstraintSet &cstr) const {
165+
populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
166+
}
167+
168+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
169+
ValueBoundsConstraintSet &cstr) const {
170+
populateBounds(cast<IfOp>(op), value, dim, cstr);
171+
}
172+
};
173+
114174
} // namespace
115175
} // namespace scf
116176
} // namespace mlir
@@ -119,5 +179,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
119179
DialectRegistry &registry) {
120180
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
121181
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
182+
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
122183
});
123184
}

mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,24 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
5959
ScalableValueBoundsConstraintSet scalableCstr(
6060
value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
6161
vscaleMin, vscaleMax);
62-
int64_t pos = scalableCstr.populateConstraintsSet(value, dim);
62+
int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
63+
scalableCstr.processWorklist();
6364

64-
// Project out all variables apart from vscale.
65-
// This should result in constraints in terms of vscale only.
65+
// Project out all columns apart from vscale and the starting point
66+
// (value/dim). This should result in constraints in terms of vscale only.
6667
auto projectOutFn = [&](ValueDim p) {
67-
return p.first != scalableCstr.getVscaleValue();
68+
bool isStartingPoint =
69+
p.first == value &&
70+
p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue);
71+
return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
6872
};
6973
scalableCstr.projectOut(projectOutFn);
7074

7175
assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
7276
scalableCstr.positionToValueDim.size() &&
7377
"inconsistent mapping state");
7478

75-
// Check that the only symbols left are vscale.
79+
// Check that the only columns left are vscale and the starting point.
7680
for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
7781
if (i == pos)
7882
continue;

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 113 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -110,25 +110,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
110110
assertValidValueDim(value, dim);
111111
#endif // NDEBUG
112112

113+
// Check if the value/dim is statically known. In that case, an affine
114+
// constant expression should be returned. This allows us to support
115+
// multiplications with constants. (Multiplications of two columns in the
116+
// constraint set is not supported.)
117+
std::optional<int64_t> constSize = std::nullopt;
113118
auto shapedType = dyn_cast<ShapedType>(value.getType());
114119
if (shapedType) {
115-
// Static dimension: return constant directly.
116120
if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
117-
return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
118-
} else {
119-
// Constant index value: return directly.
120-
if (auto constInt = ::getConstantIntValue(value))
121-
return builder.getAffineConstantExpr(*constInt);
121+
constSize = shapedType.getDimSize(*dim);
122+
} else if (auto constInt = ::getConstantIntValue(value)) {
123+
constSize = *constInt;
122124
}
123125

124-
// Dynamic value: add to constraint set.
126+
// If the value/dim is already mapped, return the corresponding expression
127+
// directly.
125128
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
126-
if (!valueDimToPosition.contains(valueDim))
127-
(void)insert(value, dim);
128-
int64_t pos = getPos(value, dim);
129-
return pos < cstr.getNumDimVars()
130-
? builder.getAffineDimExpr(pos)
131-
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
129+
if (valueDimToPosition.contains(valueDim)) {
130+
// If it is a constant, return an affine constant expression. Otherwise,
131+
// return an affine expression that represents the respective column in the
132+
// constraint set.
133+
if (constSize)
134+
return builder.getAffineConstantExpr(*constSize);
135+
return getPosExpr(getPos(value, dim));
136+
}
137+
138+
if (constSize) {
139+
// Constant index value/dim: add column to the constraint set, add EQ bound
140+
// and return an affine constant expression without pushing the newly added
141+
// column to the worklist.
142+
(void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
143+
if (shapedType)
144+
bound(value)[*dim] == *constSize;
145+
else
146+
bound(value) == *constSize;
147+
return builder.getAffineConstantExpr(*constSize);
148+
}
149+
150+
// Dynamic value/dim: insert column to the constraint set and put it on the
151+
// worklist. Return an affine expression that represents the newly inserted
152+
// column in the constraint set.
153+
return getPosExpr(insert(value, dim, /*isSymbol=*/true));
132154
}
133155

134156
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
@@ -145,7 +167,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
145167

146168
int64_t ValueBoundsConstraintSet::insert(Value value,
147169
std::optional<int64_t> dim,
148-
bool isSymbol) {
170+
bool isSymbol, bool addToWorklist) {
149171
#ifndef NDEBUG
150172
assertValidValueDim(value, dim);
151173
#endif // NDEBUG
@@ -160,7 +182,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
160182
if (positionToValueDim[i].has_value())
161183
valueDimToPosition[*positionToValueDim[i]] = i;
162184

163-
worklist.push(pos);
185+
if (addToWorklist) {
186+
LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
187+
<< " (dim: " << dim.value_or(kIndexValue) << ")\n");
188+
worklist.push(pos);
189+
}
190+
164191
return pos;
165192
}
166193

@@ -190,6 +217,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
190217
return it->second;
191218
}
192219

220+
AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
221+
assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
222+
return pos < cstr.getNumDimVars()
223+
? builder.getAffineDimExpr(pos)
224+
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
225+
}
226+
193227
static Operation *getOwnerOfValue(Value value) {
194228
if (auto bbArg = dyn_cast<BlockArgument>(value))
195229
return bbArg.getOwner()->getParentOp();
@@ -492,15 +526,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
492526

493527
// Default stop condition if none was specified: Keep adding constraints until
494528
// a bound could be computed.
495-
int64_t pos;
529+
int64_t pos = 0;
496530
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
497531
ValueBoundsConstraintSet &cstr) {
498532
return cstr.cstr.getConstantBound64(type, pos).has_value();
499533
};
500534

501535
ValueBoundsConstraintSet cstr(
502536
map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
503-
cstr.populateConstraintsSet(map, operands, &pos);
537+
pos = cstr.populateConstraints(map, operands);
538+
assert(pos == 0 && "expected `map` is the first column");
504539

505540
// Compute constant bound for `valueDim`.
506541
int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -509,29 +544,28 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
509544
return failure();
510545
}
511546

512-
int64_t
513-
ValueBoundsConstraintSet::populateConstraintsSet(Value value,
514-
std::optional<int64_t> dim) {
547+
void ValueBoundsConstraintSet::populateConstraints(Value value,
548+
std::optional<int64_t> dim) {
515549
#ifndef NDEBUG
516550
assertValidValueDim(value, dim);
517551
#endif // NDEBUG
518552

519-
AffineMap map =
520-
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
521-
Builder(value.getContext()).getAffineDimExpr(0));
522-
return populateConstraintsSet(map, {{value, dim}});
553+
// `getExpr` pushes the value/dim onto the worklist (unless it was already
554+
// analyzed).
555+
(void)getExpr(value, dim);
556+
// Process all values/dims on the worklist. This may traverse and analyze
557+
// additional IR, depending the current stop function.
558+
processWorklist();
523559
}
524560

525-
int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
526-
ValueDimList operands,
527-
int64_t *posOut) {
561+
int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
562+
ValueDimList operands) {
528563
assert(map.getNumResults() == 1 && "expected affine map with one result");
529564
int64_t pos = insert(/*isSymbol=*/false);
530-
if (posOut)
531-
*posOut = pos;
532565

533566
// Add map and operands to the constraint set. Dimensions are converted to
534-
// symbols. All operands are added to the worklist.
567+
// symbols. All operands are added to the worklist (unless they were already
568+
// processed).
535569
auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
536570
return getExpr(v.first, v.second);
537571
};
@@ -566,6 +600,55 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
566600
{{value1, dim1}, {value2, dim2}});
567601
}
568602

603+
bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
604+
ComparisonOperator cmp, Value rhs,
605+
std::optional<int64_t> rhsDim) {
606+
// This function returns "true" if "lhs CMP rhs" is proven to hold.
607+
//
608+
// Example for ComparisonOperator::LE and index-typed values: We would like to
609+
// prove that lhs <= rhs. Proof by contradiction: add the inverse
610+
// relation (lhs > rhs) to the constraint set and check if the resulting
611+
// constraint set is "empty" (i.e. has no solution). In that case,
612+
// lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
613+
614+
// We cannot prove anything if the constraint set is already empty.
615+
if (cstr.isEmpty()) {
616+
LLVM_DEBUG(
617+
llvm::dbgs()
618+
<< "cannot compare value/dims: constraint system is already empty");
619+
return false;
620+
}
621+
622+
// EQ can be expressed as LE and GE.
623+
if (cmp == EQ)
624+
return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
625+
compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
626+
627+
// Construct inequality. For the above example: lhs > rhs.
628+
// `IntegerRelation` inequalities are expressed in the "flattened" form and
629+
// with ">= 0". I.e., lhs - rhs - 1 >= 0.
630+
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
631+
if (cmp == LT || cmp == LE) {
632+
++eq[getPos(lhs, lhsDim)];
633+
--eq[getPos(rhs, rhsDim)];
634+
} else if (cmp == GT || cmp == GE) {
635+
--eq[getPos(lhs, lhsDim)];
636+
++eq[getPos(rhs, rhsDim)];
637+
} else {
638+
llvm_unreachable("unsupported comparison operator");
639+
}
640+
if (cmp == LE || cmp == GE)
641+
eq[cstr.getNumDimAndSymbolVars()] -= 1;
642+
643+
// Add inequality to the constraint set and check if it made the constraint
644+
// set empty.
645+
int64_t ineqPos = cstr.getNumInequalities();
646+
cstr.addInequality(eq);
647+
bool isEmpty = cstr.isEmpty();
648+
cstr.removeInequality(ineqPos);
649+
return isEmpty;
650+
}
651+
569652
FailureOr<bool>
570653
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
571654
std::optional<int64_t> dim1,

0 commit comments

Comments
 (0)