Skip to content

Commit 8057ddd

Browse files
[mlir][SCF] ValueBoundsConstraintSet: Support preliminary support for branches
This commit adds support for `scf.if` to `ValueBoundsConstraintSet`. Example: ``` %0 = scf.if ... -> index { scf.yield %a : index } else { scf.yield %b : index } ``` The following constraints hold for %0: * %0 >= min(%a, %b) * %0 <= max(%a, %b) Such constraints cannot be added to the constraint set; min/max is not supported by `IntegerRelation`. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b: * %0 >= %a * %0 <= %b This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.
1 parent 305001b commit 8057ddd

File tree

4 files changed

+264
-2
lines changed

4 files changed

+264
-2
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,28 @@ class ValueBoundsConstraintSet {
199199
std::optional<int64_t> dim1 = std::nullopt,
200200
std::optional<int64_t> dim2 = std::nullopt);
201201

202+
/// Traverse the IR starting from the given value/dim and add populate
203+
/// constraints as long as the currently set stop condition holds. Also
204+
/// processes all values/dims that are already on the worklist.
205+
void populateConstraints(Value value, std::optional<int64_t> dim);
206+
207+
/// Comparison operator for `ValueBoundsConstraintSet::compare`.
208+
enum ComparisonOperator { LT, LE, EQ, GT, GE };
209+
210+
/// Try to prove that, based on the current state of this constraint set
211+
/// (i.e., without analyzing additional IR or adding new constraints), it can
212+
/// be deduced that the first given value/dim is LE/LT/EQ/GT/GE than the
213+
/// second given value/dim.
214+
///
215+
/// Return "true" if the specified relation between the two values/dims was
216+
/// proven to hold. Return "false" if the specified relation could not be
217+
/// proven. This could be because the specified relation does in fact not hold
218+
/// or because there is not enough information in the constraint set. In other
219+
/// words, if we do not know for sure, this function returns "false".
220+
bool compare(Value value1, std::optional<int64_t> dim1,
221+
ComparisonOperator cmp, Value value2,
222+
std::optional<int64_t> dim2);
223+
202224
/// Compute whether the given values/dimensions are equal. Return "failure" if
203225
/// equality could not be determined.
204226
///

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,68 @@ struct ForOpInterface
111111
}
112112
};
113113

114+
struct IfOpInterface
115+
: public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
116+
117+
void populateBoundsForIndexValue(Operation *op, Value value,
118+
ValueBoundsConstraintSet &cstr) const {
119+
auto ifOp = cast<IfOp>(op);
120+
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
121+
Value thenValue = ifOp.thenYield().getResults()[resultNum];
122+
Value elseValue = ifOp.elseYield().getResults()[resultNum];
123+
124+
// Populate constraints for the yielded value (and all values on the
125+
// backward slice, as long as the current stop condition is not satisfied).
126+
cstr.populateConstraints(thenValue, /*valueDim=*/std::nullopt);
127+
cstr.populateConstraints(elseValue, /*valueDim=*/std::nullopt);
128+
129+
// Compare yielded values.
130+
// If thenValue <= elseValue:
131+
// * result <= elseValue
132+
// * result >= thenValue
133+
if (cstr.compare(thenValue, /*dim1=*/std::nullopt,
134+
ValueBoundsConstraintSet::ComparisonOperator::LE,
135+
elseValue, /*dim2=*/std::nullopt)) {
136+
cstr.bound(value) >= thenValue;
137+
cstr.bound(value) <= elseValue;
138+
}
139+
// If elseValue <= thenValue:
140+
// * result <= thenValue
141+
// * result >= elseValue
142+
if (cstr.compare(elseValue, /*dim1=*/std::nullopt,
143+
ValueBoundsConstraintSet::ComparisonOperator::LE,
144+
thenValue, /*dim2=*/std::nullopt)) {
145+
cstr.bound(value) >= elseValue;
146+
cstr.bound(value) <= thenValue;
147+
}
148+
}
149+
150+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
151+
ValueBoundsConstraintSet &cstr) const {
152+
// See `populateBoundsForIndexValue` for documentation.
153+
auto ifOp = cast<IfOp>(op);
154+
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
155+
Value thenValue = ifOp.thenYield().getResults()[resultNum];
156+
Value elseValue = ifOp.elseYield().getResults()[resultNum];
157+
158+
cstr.populateConstraints(thenValue, dim);
159+
cstr.populateConstraints(elseValue, dim);
160+
161+
if (cstr.compare(thenValue, dim,
162+
ValueBoundsConstraintSet::ComparisonOperator::LE,
163+
elseValue, dim)) {
164+
cstr.bound(value)[dim] >= cstr.getExpr(thenValue, dim);
165+
cstr.bound(value)[dim] <= cstr.getExpr(elseValue, dim);
166+
}
167+
if (cstr.compare(elseValue, dim,
168+
ValueBoundsConstraintSet::ComparisonOperator::LE,
169+
thenValue, dim)) {
170+
cstr.bound(value)[dim] >= cstr.getExpr(elseValue, dim);
171+
cstr.bound(value)[dim] <= cstr.getExpr(thenValue, dim);
172+
}
173+
}
174+
};
175+
114176
} // namespace
115177
} // namespace scf
116178
} // namespace mlir
@@ -119,5 +181,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
119181
DialectRegistry &registry) {
120182
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
121183
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
184+
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
122185
});
123186
}

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,68 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
575575
{{value1, dim1}, {value2, dim2}});
576576
}
577577

578+
void ValueBoundsConstraintSet::populateConstraints(Value value,
579+
std::optional<int64_t> dim) {
580+
// `getExpr` pushes the value/dim onto the worklist (unless it was already
581+
// analyzed).
582+
(void)getExpr(value, dim);
583+
// Process all values/dims on the worklist. This may traverse and analyze
584+
// additional IR, depending the current stop function.
585+
processWorklist();
586+
}
587+
588+
bool ValueBoundsConstraintSet::compare(Value value1,
589+
std::optional<int64_t> dim1,
590+
ComparisonOperator cmp, Value value2,
591+
std::optional<int64_t> dim2) {
592+
// This function returns "true" if value1/dim1 CMP value2/dim2 is proved to
593+
// hold.
594+
//
595+
// Example for ComparisonOperator::LE and index-typed values: We would like to
596+
// prove that value1 <= value2. Proof by contradiction: add the inverse
597+
// relation (value1 > value2) to the constraint set and check if the resulting
598+
// constraint set is "empty" (i.e. has no solution). In that case,
599+
// value1 > value2 must be incorrect and we can deduce that value1 <= value2
600+
// holds.
601+
602+
// We cannot use prove anything if the constraint set is already empty.
603+
if (cstr.isEmpty()) {
604+
LLVM_DEBUG(
605+
llvm::dbgs()
606+
<< "cannot compare value/dims: constraint system is already empty");
607+
return false;
608+
}
609+
610+
// EQ can be expressed as LE and GE.
611+
if (cmp == EQ)
612+
return compare(value1, dim1, ComparisonOperator::LE, value2, dim2) &&
613+
compare(value1, dim1, ComparisonOperator::GE, value2, dim2);
614+
615+
// Construct inequality. For the above example: value1 > value2.
616+
// `IntegerRelation` inequalities are expressed in the "flattened" form and
617+
// with ">= 0". I.e., value1 - value2 - 1 >= 0.
618+
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
619+
if (cmp == LT || cmp == LE) {
620+
eq[getPos(value1, dim1)]++;
621+
eq[getPos(value2, dim2)]--;
622+
} else if (cmp == GT || cmp == GE) {
623+
eq[getPos(value1, dim1)]--;
624+
eq[getPos(value2, dim2)]++;
625+
} else {
626+
llvm_unreachable("unsupported comparison operator");
627+
}
628+
if (cmp == LE || cmp == GE)
629+
eq[cstr.getNumDimAndSymbolVars()] -= 1;
630+
631+
// Add inequality to the constraint set and check if it made the constraint
632+
// set empty.
633+
int64_t ineqPos = cstr.getNumInequalities();
634+
cstr.addInequality(eq);
635+
bool isEmpty = cstr.isEmpty();
636+
cstr.removeInequality(ineqPos);
637+
return isEmpty;
638+
}
639+
578640
FailureOr<bool>
579641
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
580642
std::optional<int64_t> dim1,

mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
2-
// RUN: -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
2+
// RUN: -verify-diagnostics -split-input-file | FileCheck %s
33

44
// CHECK-LABEL: func @scf_for(
55
// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
@@ -104,3 +104,118 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in
104104
"test.some_use"(%reify1) : (index) -> ()
105105
return
106106
}
107+
108+
// -----
109+
110+
// CHECK-LABEL: func @scf_if_constant(
111+
func.func @scf_if_constant(%c : i1) {
112+
// CHECK: arith.constant 4 : index
113+
// CHECK: arith.constant 9 : index
114+
%c4 = arith.constant 4 : index
115+
%c9 = arith.constant 9 : index
116+
%r = scf.if %c -> index {
117+
scf.yield %c4 : index
118+
} else {
119+
scf.yield %c9 : index
120+
}
121+
122+
// CHECK: %[[c4:.*]] = arith.constant 4 : index
123+
// CHECK: %[[c10:.*]] = arith.constant 10 : index
124+
%reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
125+
%reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
126+
// CHECK: "test.some_use"(%[[c4]], %[[c10]])
127+
"test.some_use"(%reify1, %reify2) : (index, index) -> ()
128+
return
129+
}
130+
131+
// -----
132+
133+
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
134+
// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
135+
// CHECK-LABEL: func @scf_if_dynamic(
136+
// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
137+
func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) {
138+
%c4 = arith.constant 4 : index
139+
%r = scf.if %c -> index {
140+
%add1 = arith.addi %a, %b : index
141+
scf.yield %add1 : index
142+
} else {
143+
%add2 = arith.addi %b, %c4 : index
144+
%add3 = arith.addi %add2, %a : index
145+
scf.yield %add3 : index
146+
}
147+
148+
// CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
149+
// CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]]
150+
%reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
151+
%reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
152+
// CHECK: "test.some_use"(%[[lb]], %[[ub]])
153+
"test.some_use"(%reify1, %reify2) : (index, index) -> ()
154+
return
155+
}
156+
157+
// -----
158+
159+
func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) {
160+
%r = scf.if %c -> index {
161+
scf.yield %a : index
162+
} else {
163+
scf.yield %b : index
164+
}
165+
// The reified bound would be min(%a, %b). min/max expressions are not
166+
// supported in reified bounds.
167+
// expected-error @below{{could not reify bound}}
168+
%reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
169+
"test.some_use"(%reify1) : (index) -> ()
170+
return
171+
}
172+
173+
// -----
174+
175+
// CHECK-LABEL: func @scf_if_tensor_dim(
176+
func.func @scf_if_tensor_dim(%c : i1) {
177+
// CHECK: arith.constant 4 : index
178+
// CHECK: arith.constant 9 : index
179+
%c4 = arith.constant 4 : index
180+
%c9 = arith.constant 9 : index
181+
%t1 = tensor.empty(%c4) : tensor<?xf32>
182+
%t2 = tensor.empty(%c9) : tensor<?xf32>
183+
%r = scf.if %c -> tensor<?xf32> {
184+
scf.yield %t1 : tensor<?xf32>
185+
} else {
186+
scf.yield %t2 : tensor<?xf32>
187+
}
188+
189+
// CHECK: %[[c4:.*]] = arith.constant 4 : index
190+
// CHECK: %[[c10:.*]] = arith.constant 10 : index
191+
%reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0}
192+
: (tensor<?xf32>) -> (index)
193+
%reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0}
194+
: (tensor<?xf32>) -> (index)
195+
// CHECK: "test.some_use"(%[[c4]], %[[c10]])
196+
"test.some_use"(%reify1, %reify2) : (index, index) -> ()
197+
return
198+
}
199+
200+
// -----
201+
202+
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
203+
// CHECK-LABEL: func @scf_if_eq(
204+
// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
205+
func.func @scf_if_eq(%a: index, %b: index, %c : i1) {
206+
%c0 = arith.constant 0 : index
207+
%r = scf.if %c -> index {
208+
%add1 = arith.addi %a, %b : index
209+
scf.yield %add1 : index
210+
} else {
211+
%add2 = arith.addi %b, %c0 : index
212+
%add3 = arith.addi %add2, %a : index
213+
scf.yield %add3 : index
214+
}
215+
216+
// CHECK: %[[eq:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
217+
%reify1 = "test.reify_bound"(%r) {type = "EQ"} : (index) -> (index)
218+
// CHECK: "test.some_use"(%[[eq]])
219+
"test.some_use"(%reify1) : (index) -> ()
220+
return
221+
}

0 commit comments

Comments
 (0)