Skip to content

Commit 680d04d

Browse files
[mlir][Arith] ValueBoundsOpInterface: Support arith.select
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.
1 parent 3c4adb5 commit 680d04d

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,75 @@ struct MulIOpInterface
6666
}
6767
};
6868

69+
struct SelectOpInterface
70+
: public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
71+
SelectOp> {
72+
73+
static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
74+
ValueBoundsConstraintSet &cstr) {
75+
Value value = selectOp.getResult();
76+
Value condition = selectOp.getCondition();
77+
Value trueValue = selectOp.getTrueValue();
78+
Value falseValue = selectOp.getFalseValue();
79+
80+
if (isa<ShapedType>(condition.getType())) {
81+
// If the condition is a shaped type, the condition is applied
82+
// element-wise. All three operands must have the same shape.
83+
cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
84+
cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
85+
cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
86+
return;
87+
}
88+
89+
// Populate constraints for the true/false values (and all values on the
90+
// backward slice, as long as the current stop condition is not satisfied).
91+
cstr.populateConstraints(trueValue, dim);
92+
cstr.populateConstraints(falseValue, dim);
93+
auto boundsBuilder = cstr.bound(value);
94+
if (dim)
95+
boundsBuilder[*dim];
96+
97+
// Compare yielded values.
98+
// If trueValue <= falseValue:
99+
// * result <= falseValue
100+
// * result >= trueValue
101+
if (cstr.compare(trueValue, dim,
102+
ValueBoundsConstraintSet::ComparisonOperator::LE,
103+
falseValue, dim)) {
104+
if (dim) {
105+
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
106+
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
107+
} else {
108+
cstr.bound(value) >= trueValue;
109+
cstr.bound(value) <= falseValue;
110+
}
111+
}
112+
// If falseValue <= trueValue:
113+
// * result <= trueValue
114+
// * result >= falseValue
115+
if (cstr.compare(falseValue, dim,
116+
ValueBoundsConstraintSet::ComparisonOperator::LE,
117+
trueValue, dim)) {
118+
if (dim) {
119+
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
120+
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
121+
} else {
122+
cstr.bound(value) >= falseValue;
123+
cstr.bound(value) <= trueValue;
124+
}
125+
}
126+
}
127+
128+
void populateBoundsForIndexValue(Operation *op, Value value,
129+
ValueBoundsConstraintSet &cstr) const {
130+
populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
131+
}
132+
133+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
134+
ValueBoundsConstraintSet &cstr) const {
135+
populateBounds(cast<SelectOp>(op), dim, cstr);
136+
}
137+
};
69138
} // namespace
70139
} // namespace arith
71140
} // namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
77146
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
78147
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
79148
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
149+
arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
80150
});
81151
}

mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
7474
%0 = "test.reify_bound"(%c5) : (index) -> (index)
7575
return %0 : index
7676
}
77+
78+
// -----
79+
80+
// CHECK-LABEL: func @arith_select(
81+
func.func @arith_select(%c: i1) -> (index, index) {
82+
// CHECK: arith.constant 5 : index
83+
%c5 = arith.constant 5 : index
84+
// CHECK: arith.constant 9 : index
85+
%c9 = arith.constant 9 : index
86+
%r = arith.select %c, %c5, %c9 : index
87+
// CHECK: %[[c5:.*]] = arith.constant 5 : index
88+
// CHECK: %[[c10:.*]] = arith.constant 10 : index
89+
%0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
90+
%1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
91+
// CHECK: return %[[c5]], %[[c10]]
92+
return %0, %1 : index, index
93+
}
94+
95+
// -----
96+
97+
// CHECK-LABEL: func @arith_select_elementwise(
98+
// CHECK-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
99+
func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
100+
%r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
101+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
102+
// CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
103+
%0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
104+
: (tensor<?xf32>) -> (index)
105+
// CHECK: return %[[dim]]
106+
return %0 : index
107+
}

0 commit comments

Comments
 (0)