Skip to content

Commit e55463c

Browse files
[mlir][Arith] ValueBoundsOpInterface: Support arith.select (#86383)
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.
1 parent 0ba3e96 commit e55463c

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,75 @@ struct MulIOpInterface
7575
}
7676
};
7777

78+
struct SelectOpInterface
79+
: public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
80+
SelectOp> {
81+
82+
static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
83+
ValueBoundsConstraintSet &cstr) {
84+
Value value = selectOp.getResult();
85+
Value condition = selectOp.getCondition();
86+
Value trueValue = selectOp.getTrueValue();
87+
Value falseValue = selectOp.getFalseValue();
88+
89+
if (isa<ShapedType>(condition.getType())) {
90+
// If the condition is a shaped type, the condition is applied
91+
// element-wise. All three operands must have the same shape.
92+
cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
93+
cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
94+
cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
95+
return;
96+
}
97+
98+
// Populate constraints for the true/false values (and all values on the
99+
// backward slice, as long as the current stop condition is not satisfied).
100+
cstr.populateConstraints(trueValue, dim);
101+
cstr.populateConstraints(falseValue, dim);
102+
auto boundsBuilder = cstr.bound(value);
103+
if (dim)
104+
boundsBuilder[*dim];
105+
106+
// Compare yielded values.
107+
// If trueValue <= falseValue:
108+
// * result <= falseValue
109+
// * result >= trueValue
110+
if (cstr.compare(trueValue, dim,
111+
ValueBoundsConstraintSet::ComparisonOperator::LE,
112+
falseValue, dim)) {
113+
if (dim) {
114+
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
115+
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
116+
} else {
117+
cstr.bound(value) >= trueValue;
118+
cstr.bound(value) <= falseValue;
119+
}
120+
}
121+
// If falseValue <= trueValue:
122+
// * result <= trueValue
123+
// * result >= falseValue
124+
if (cstr.compare(falseValue, dim,
125+
ValueBoundsConstraintSet::ComparisonOperator::LE,
126+
trueValue, dim)) {
127+
if (dim) {
128+
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
129+
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
130+
} else {
131+
cstr.bound(value) >= falseValue;
132+
cstr.bound(value) <= trueValue;
133+
}
134+
}
135+
}
136+
137+
void populateBoundsForIndexValue(Operation *op, Value value,
138+
ValueBoundsConstraintSet &cstr) const {
139+
populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
140+
}
141+
142+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
143+
ValueBoundsConstraintSet &cstr) const {
144+
populateBounds(cast<SelectOp>(op), dim, cstr);
145+
}
146+
};
78147
} // namespace
79148
} // namespace arith
80149
} // namespace mlir
@@ -86,5 +155,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
86155
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
87156
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
88157
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
158+
arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
89159
});
90160
}

mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
7474
%0 = "test.reify_bound"(%c5) : (index) -> (index)
7575
return %0 : index
7676
}
77+
78+
// -----
79+
80+
// CHECK-LABEL: func @arith_select(
81+
func.func @arith_select(%c: i1) -> (index, index) {
82+
// CHECK: arith.constant 5 : index
83+
%c5 = arith.constant 5 : index
84+
// CHECK: arith.constant 9 : index
85+
%c9 = arith.constant 9 : index
86+
%r = arith.select %c, %c5, %c9 : index
87+
// CHECK: %[[c5:.*]] = arith.constant 5 : index
88+
// CHECK: %[[c10:.*]] = arith.constant 10 : index
89+
%0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
90+
%1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
91+
// CHECK: return %[[c5]], %[[c10]]
92+
return %0, %1 : index, index
93+
}
94+
95+
// -----
96+
97+
// CHECK-LABEL: func @arith_select_elementwise(
98+
// CHECK-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
99+
func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
100+
%r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
101+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
102+
// CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
103+
%0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
104+
: (tensor<?xf32>) -> (index)
105+
// CHECK: return %[[dim]]
106+
return %0 : index
107+
}

0 commit comments

Comments
 (0)