Skip to content

Commit b4bab14

Browse files
[mlir][SCF] ValueBoundsConstraintSet: Support preliminary support for branches
This commit adds support for `scf.if` to `ValueBoundsConstraintSet`. Example: ``` %0 = scf.if ... -> index { scf.yield %a : index } else { scf.yield %b : index } ``` The following constraints hold for %0: * %0 >= min(%a, %b) * %0 <= max(%a, %b) Such constraints cannot be added to the constraint set; min/max is not supported by `IntegerRelation`. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b: * %0 >= %a * %0 <= %b This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.
1 parent ad1b2ac commit b4bab14

File tree

5 files changed

+278
-29
lines changed

5 files changed

+278
-29
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,26 @@ class ValueBoundsConstraintSet
203203
std::optional<int64_t> dim1 = std::nullopt,
204204
std::optional<int64_t> dim2 = std::nullopt);
205205

206+
/// Traverse the IR starting from the given value/dim and populate constraints
207+
/// as long as the stop condition holds. Also process all values/dims that are
208+
/// already on the worklist.
209+
void populateConstraints(Value value, std::optional<int64_t> dim);
210+
211+
/// Comparison operator for `ValueBoundsConstraintSet::compare`.
212+
enum ComparisonOperator { LT, LE, EQ, GT, GE };
213+
214+
/// Try to prove that, based on the current state of this constraint set
215+
/// (i.e., without analyzing additional IR or adding new constraints), the
216+
/// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
217+
///
218+
/// Return "true" if the specified relation between the two values/dims was
219+
/// proven to hold. Return "false" if the specified relation could not be
220+
/// proven. This could be because the specified relation does in fact not hold
221+
/// or because there is not enough information in the constraint set. In other
222+
/// words, if we do not know for sure, this function returns "false".
223+
bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
224+
Value rhs, std::optional<int64_t> rhsDim);
225+
206226
/// Compute whether the given values/dimensions are equal. Return "failure" if
207227
/// equality could not be determined.
208228
///
@@ -270,13 +290,13 @@ class ValueBoundsConstraintSet
270290

271291
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
272292

273-
/// Populates the constraint set for a value/map without actually computing
274-
/// the bound. Returns the position for the value/map (via the return value
275-
/// and `posOut` output parameter).
276-
int64_t populateConstraintsSet(Value value,
277-
std::optional<int64_t> dim = std::nullopt);
278-
int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
279-
int64_t *posOut = nullptr);
293+
/// Given an affine map with a single result (and map operands), add a new
294+
/// column to the constraint set that represents the result of the map.
295+
/// Traverse additional IR starting from the map operands as needed (as long
296+
/// as the stop condition is not satisfied). Also process all values/dims that
297+
/// are already on the worklist. Return the position of the newly added
298+
/// column.
299+
int64_t populateConstraints(AffineMap map, ValueDimList mapOperands);
280300

281301
/// Iteratively process all elements on the worklist until an index-typed
282302
/// value or shaped value meets `stopCondition`. Such values are not processed

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,66 @@ struct ForOpInterface
111111
}
112112
};
113113

114+
struct IfOpInterface
115+
: public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
116+
117+
static void populateBounds(scf::IfOp ifOp, Value value,
118+
std::optional<int64_t> dim,
119+
ValueBoundsConstraintSet &cstr) {
120+
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
121+
Value thenValue = ifOp.thenYield().getResults()[resultNum];
122+
Value elseValue = ifOp.elseYield().getResults()[resultNum];
123+
124+
// Populate constraints for the yielded value (and all values on the
125+
// backward slice, as long as the current stop condition is not satisfied).
126+
cstr.populateConstraints(thenValue, dim);
127+
cstr.populateConstraints(elseValue, dim);
128+
auto boundsBuilder = cstr.bound(value);
129+
if (dim)
130+
boundsBuilder[*dim];
131+
132+
// Compare yielded values.
133+
// If thenValue <= elseValue:
134+
// * result <= elseValue
135+
// * result >= thenValue
136+
if (cstr.compare(thenValue, dim,
137+
ValueBoundsConstraintSet::ComparisonOperator::LE,
138+
elseValue, dim)) {
139+
if (dim) {
140+
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
141+
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
142+
} else {
143+
cstr.bound(value) >= thenValue;
144+
cstr.bound(value) <= elseValue;
145+
}
146+
}
147+
// If elseValue <= thenValue:
148+
// * result <= thenValue
149+
// * result >= elseValue
150+
if (cstr.compare(elseValue, dim,
151+
ValueBoundsConstraintSet::ComparisonOperator::LE,
152+
thenValue, dim)) {
153+
if (dim) {
154+
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
155+
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
156+
} else {
157+
cstr.bound(value) >= elseValue;
158+
cstr.bound(value) <= thenValue;
159+
}
160+
}
161+
}
162+
163+
void populateBoundsForIndexValue(Operation *op, Value value,
164+
ValueBoundsConstraintSet &cstr) const {
165+
populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
166+
}
167+
168+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
169+
ValueBoundsConstraintSet &cstr) const {
170+
populateBounds(cast<IfOp>(op), value, dim, cstr);
171+
}
172+
};
173+
114174
} // namespace
115175
} // namespace scf
116176
} // namespace mlir
@@ -119,5 +179,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
119179
DialectRegistry &registry) {
120180
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
121181
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
182+
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
122183
});
123184
}

mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,24 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
5959
ScalableValueBoundsConstraintSet scalableCstr(
6060
value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
6161
vscaleMin, vscaleMax);
62-
int64_t pos = scalableCstr.populateConstraintsSet(value, dim);
62+
int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
63+
scalableCstr.processWorklist();
6364

64-
// Project out all variables apart from vscale.
65-
// This should result in constraints in terms of vscale only.
65+
// Project out all columns apart from vscale and the starting point
66+
// (value/dim). This should result in constraints in terms of vscale only.
6667
auto projectOutFn = [&](ValueDim p) {
67-
return p.first != scalableCstr.getVscaleValue();
68+
bool isStartingPoint =
69+
p.first == value &&
70+
p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue);
71+
return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
6872
};
6973
scalableCstr.projectOut(projectOutFn);
7074

7175
assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
7276
scalableCstr.positionToValueDim.size() &&
7377
"inconsistent mapping state");
7478

75-
// Check that the only symbols left are vscale.
79+
// Check that the only columns left are vscale and the starting point.
7680
for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
7781
if (i == pos)
7882
continue;

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -529,15 +529,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
529529

530530
// Default stop condition if none was specified: Keep adding constraints until
531531
// a bound could be computed.
532-
int64_t pos;
532+
int64_t pos = 0;
533533
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
534534
ValueBoundsConstraintSet &cstr) {
535535
return cstr.cstr.getConstantBound64(type, pos).has_value();
536536
};
537537

538538
ValueBoundsConstraintSet cstr(
539539
map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
540-
cstr.populateConstraintsSet(map, operands, &pos);
540+
pos = cstr.populateConstraints(map, operands);
541+
assert(pos == 0 && "expected `map` is the first column");
541542

542543
// Compute constant bound for `valueDim`.
543544
int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -546,29 +547,28 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
546547
return failure();
547548
}
548549

549-
int64_t
550-
ValueBoundsConstraintSet::populateConstraintsSet(Value value,
551-
std::optional<int64_t> dim) {
550+
void ValueBoundsConstraintSet::populateConstraints(Value value,
551+
std::optional<int64_t> dim) {
552552
#ifndef NDEBUG
553553
assertValidValueDim(value, dim);
554554
#endif // NDEBUG
555555

556-
AffineMap map =
557-
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
558-
Builder(value.getContext()).getAffineDimExpr(0));
559-
return populateConstraintsSet(map, {{value, dim}});
556+
// `getExpr` pushes the value/dim onto the worklist (unless it was already
557+
// analyzed).
558+
(void)getExpr(value, dim);
559+
// Process all values/dims on the worklist. This may traverse and analyze
560+
// additional IR, depending the current stop function.
561+
processWorklist();
560562
}
561563

562-
int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
563-
ValueDimList operands,
564-
int64_t *posOut) {
564+
int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
565+
ValueDimList operands) {
565566
assert(map.getNumResults() == 1 && "expected affine map with one result");
566567
int64_t pos = insert(/*isSymbol=*/false);
567-
if (posOut)
568-
*posOut = pos;
569568

570569
// Add map and operands to the constraint set. Dimensions are converted to
571-
// symbols. All operands are added to the worklist.
570+
// symbols. All operands are added to the worklist (unless they were already
571+
// processed).
572572
auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
573573
return getExpr(v.first, v.second);
574574
};
@@ -603,6 +603,55 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
603603
{{value1, dim1}, {value2, dim2}});
604604
}
605605

606+
bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
607+
ComparisonOperator cmp, Value rhs,
608+
std::optional<int64_t> rhsDim) {
609+
// This function returns "true" if "lhs CMP rhs" is proven to hold.
610+
//
611+
// Example for ComparisonOperator::LE and index-typed values: We would like to
612+
// prove that lhs <= rhs. Proof by contradiction: add the inverse
613+
// relation (lhs > rhs) to the constraint set and check if the resulting
614+
// constraint set is "empty" (i.e. has no solution). In that case,
615+
// lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
616+
617+
// We cannot prove anything if the constraint set is already empty.
618+
if (cstr.isEmpty()) {
619+
LLVM_DEBUG(
620+
llvm::dbgs()
621+
<< "cannot compare value/dims: constraint system is already empty");
622+
return false;
623+
}
624+
625+
// EQ can be expressed as LE and GE.
626+
if (cmp == EQ)
627+
return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
628+
compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
629+
630+
// Construct inequality. For the above example: lhs > rhs.
631+
// `IntegerRelation` inequalities are expressed in the "flattened" form and
632+
// with ">= 0". I.e., lhs - rhs - 1 >= 0.
633+
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
634+
if (cmp == LT || cmp == LE) {
635+
++eq[getPos(lhs, lhsDim)];
636+
--eq[getPos(rhs, rhsDim)];
637+
} else if (cmp == GT || cmp == GE) {
638+
--eq[getPos(lhs, lhsDim)];
639+
++eq[getPos(rhs, rhsDim)];
640+
} else {
641+
llvm_unreachable("unsupported comparison operator");
642+
}
643+
if (cmp == LE || cmp == GE)
644+
eq[cstr.getNumDimAndSymbolVars()] -= 1;
645+
646+
// Add inequality to the constraint set and check if it made the constraint
647+
// set empty.
648+
int64_t ineqPos = cstr.getNumInequalities();
649+
cstr.addInequality(eq);
650+
bool isEmpty = cstr.isEmpty();
651+
cstr.removeInequality(ineqPos);
652+
return isEmpty;
653+
}
654+
606655
FailureOr<bool>
607656
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
608657
std::optional<int64_t> dim1,

mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
2-
// RUN: -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
2+
// RUN: -verify-diagnostics -split-input-file | FileCheck %s
33

44
// CHECK-LABEL: func @scf_for(
55
// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
@@ -104,3 +104,118 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in
104104
"test.some_use"(%reify1) : (index) -> ()
105105
return
106106
}
107+
108+
// -----
109+
110+
// CHECK-LABEL: func @scf_if_constant(
111+
func.func @scf_if_constant(%c : i1) {
112+
// CHECK: arith.constant 4 : index
113+
// CHECK: arith.constant 9 : index
114+
%c4 = arith.constant 4 : index
115+
%c9 = arith.constant 9 : index
116+
%r = scf.if %c -> index {
117+
scf.yield %c4 : index
118+
} else {
119+
scf.yield %c9 : index
120+
}
121+
122+
// CHECK: %[[c4:.*]] = arith.constant 4 : index
123+
// CHECK: %[[c10:.*]] = arith.constant 10 : index
124+
%reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
125+
%reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
126+
// CHECK: "test.some_use"(%[[c4]], %[[c10]])
127+
"test.some_use"(%reify1, %reify2) : (index, index) -> ()
128+
return
129+
}
130+
131+
// -----
132+
133+
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
134+
// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
135+
// CHECK-LABEL: func @scf_if_dynamic(
136+
// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
137+
func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) {
138+
%c4 = arith.constant 4 : index
139+
%r = scf.if %c -> index {
140+
%add1 = arith.addi %a, %b : index
141+
scf.yield %add1 : index
142+
} else {
143+
%add2 = arith.addi %b, %c4 : index
144+
%add3 = arith.addi %add2, %a : index
145+
scf.yield %add3 : index
146+
}
147+
148+
// CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
149+
// CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]]
150+
%reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
151+
%reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
152+
// CHECK: "test.some_use"(%[[lb]], %[[ub]])
153+
"test.some_use"(%reify1, %reify2) : (index, index) -> ()
154+
return
155+
}
156+
157+
// -----
158+
159+
func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) {
160+
%r = scf.if %c -> index {
161+
scf.yield %a : index
162+
} else {
163+
scf.yield %b : index
164+
}
165+
// The reified bound would be min(%a, %b). min/max expressions are not
166+
// supported in reified bounds.
167+
// expected-error @below{{could not reify bound}}
168+
%reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
169+
"test.some_use"(%reify1) : (index) -> ()
170+
return
171+
}
172+
173+
// -----
174+
175+
// CHECK-LABEL: func @scf_if_tensor_dim(
176+
func.func @scf_if_tensor_dim(%c : i1) {
177+
// CHECK: arith.constant 4 : index
178+
// CHECK: arith.constant 9 : index
179+
%c4 = arith.constant 4 : index
180+
%c9 = arith.constant 9 : index
181+
%t1 = tensor.empty(%c4) : tensor<?xf32>
182+
%t2 = tensor.empty(%c9) : tensor<?xf32>
183+
%r = scf.if %c -> tensor<?xf32> {
184+
scf.yield %t1 : tensor<?xf32>
185+
} else {
186+
scf.yield %t2 : tensor<?xf32>
187+
}
188+
189+
// CHECK: %[[c4:.*]] = arith.constant 4 : index
190+
// CHECK: %[[c10:.*]] = arith.constant 10 : index
191+
%reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0}
192+
: (tensor<?xf32>) -> (index)
193+
%reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0}
194+
: (tensor<?xf32>) -> (index)
195+
// CHECK: "test.some_use"(%[[c4]], %[[c10]])
196+
"test.some_use"(%reify1, %reify2) : (index, index) -> ()
197+
return
198+
}
199+
200+
// -----
201+
202+
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
203+
// CHECK-LABEL: func @scf_if_eq(
204+
// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
205+
func.func @scf_if_eq(%a: index, %b: index, %c : i1) {
206+
%c0 = arith.constant 0 : index
207+
%r = scf.if %c -> index {
208+
%add1 = arith.addi %a, %b : index
209+
scf.yield %add1 : index
210+
} else {
211+
%add2 = arith.addi %b, %c0 : index
212+
%add3 = arith.addi %add2, %a : index
213+
scf.yield %add3 : index
214+
}
215+
216+
// CHECK: %[[eq:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
217+
%reify1 = "test.reify_bound"(%r) {type = "EQ"} : (index) -> (index)
218+
// CHECK: "test.some_use"(%[[eq]])
219+
"test.some_use"(%reify1) : (index) -> ()
220+
return
221+
}

0 commit comments

Comments
 (0)