-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][Arith] ValueBoundsOpInterface
: Support arith.select
#86383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Arith] ValueBoundsOpInterface
: Support arith.select
#86383
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Matthias Springer (matthias-springer) ChangesThis commit adds a Full diff: https://github.com/llvm/llvm-project/pull/86383.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 9c6b50e767ea26..bb7b9c939fcb09 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,6 +66,75 @@ struct MulIOpInterface
}
};
+struct SelectOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+ SelectOp> {
+
+ static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ Value value = selectOp.getResult();
+ Value condition = selectOp.getCondition();
+ Value trueValue = selectOp.getTrueValue();
+ Value falseValue = selectOp.getFalseValue();
+
+ if (isa<ShapedType>(condition.getType())) {
+ // If the condition is a shaped type, the condition is applied
+ // element-wise. All three operands must have the same shape.
+ cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+ cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+ cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+ return;
+ }
+
+ // Populate constraints for the true/false values (and all values on the
+ // backward slice, as long as the current stop condition is not satisfied).
+ cstr.populateConstraints(trueValue, dim);
+ cstr.populateConstraints(falseValue, dim);
+ auto boundsBuilder = cstr.bound(value);
+ if (dim)
+ boundsBuilder[*dim];
+
+ // Compare yielded values.
+ // If trueValue <= falseValue:
+ // * result <= falseValue
+ // * result >= trueValue
+ if (cstr.compare(trueValue, dim,
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ falseValue, dim)) {
+ if (dim) {
+ cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+ cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+ } else {
+ cstr.bound(value) >= trueValue;
+ cstr.bound(value) <= falseValue;
+ }
+ }
+ // If falseValue <= trueValue:
+ // * result <= trueValue
+ // * result >= falseValue
+ if (cstr.compare(falseValue, dim,
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ trueValue, dim)) {
+ if (dim) {
+ cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+ cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+ } else {
+ cstr.bound(value) >= falseValue;
+ cstr.bound(value) <= trueValue;
+ }
+ }
+ }
+
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+ }
+
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ populateBounds(cast<SelectOp>(op), dim, cstr);
+ }
+};
} // namespace
} // namespace arith
} // namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+ arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
index 83d5f1c9c9e86c..8fb3ba1a1eccef 100644
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
%0 = "test.reify_bound"(%c5) : (index) -> (index)
return %0 : index
}
+
+// -----
+
+// CHECK-LABEL: func @arith_select(
+func.func @arith_select(%c: i1) -> (index, index) {
+ // CHECK: arith.constant 5 : index
+ %c5 = arith.constant 5 : index
+ // CHECK: arith.constant 9 : index
+ %c9 = arith.constant 9 : index
+ %r = arith.select %c, %c5, %c9 : index
+ // CHECK: %[[c5:.*]] = arith.constant 5 : index
+ // CHECK: %[[c10:.*]] = arith.constant 10 : index
+ %0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+ %1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+ // CHECK: return %[[c5]], %[[c10]]
+ return %0, %1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @arith_select_elementwise(
+// CHECK-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
+func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
+ %r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
+ // CHECK: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
+ %0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
+ : (tensor<?xf32>) -> (index)
+ // CHECK: return %[[dim]]
+ return %0 : index
+}
|
3c4adb5
to
f8f4249
Compare
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.
680d04d
to
4ef0c78
Compare
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand. Note: This is a re-upload of #86383.
This commit adds a
ValueBoundsOpInterface
implementation forarith.select
. The implementation is almost identical toscf.if
(#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.