@@ -66,6 +66,75 @@ struct MulIOpInterface
6666 }
6767};
6868
69+ struct SelectOpInterface
70+ : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
71+ SelectOp> {
72+
73+ static void populateBounds (SelectOp selectOp, std::optional<int64_t > dim,
74+ ValueBoundsConstraintSet &cstr) {
75+ Value value = selectOp.getResult ();
76+ Value condition = selectOp.getCondition ();
77+ Value trueValue = selectOp.getTrueValue ();
78+ Value falseValue = selectOp.getFalseValue ();
79+
80+ if (isa<ShapedType>(condition.getType ())) {
81+ // If the condition is a shaped type, the condition is applied
82+ // element-wise. All three operands must have the same shape.
83+ cstr.bound (value)[*dim] == cstr.getExpr (trueValue, dim);
84+ cstr.bound (value)[*dim] == cstr.getExpr (falseValue, dim);
85+ cstr.bound (value)[*dim] == cstr.getExpr (condition, dim);
86+ return ;
87+ }
88+
89+ // Populate constraints for the true/false values (and all values on the
90+ // backward slice, as long as the current stop condition is not satisfied).
91+ cstr.populateConstraints (trueValue, dim);
92+ cstr.populateConstraints (falseValue, dim);
93+ auto boundsBuilder = cstr.bound (value);
94+ if (dim)
95+ boundsBuilder[*dim];
96+
97+ // Compare yielded values.
98+ // If trueValue <= falseValue:
99+ // * result <= falseValue
100+ // * result >= trueValue
101+ if (cstr.compare (trueValue, dim,
102+ ValueBoundsConstraintSet::ComparisonOperator::LE,
103+ falseValue, dim)) {
104+ if (dim) {
105+ cstr.bound (value)[*dim] >= cstr.getExpr (trueValue, dim);
106+ cstr.bound (value)[*dim] <= cstr.getExpr (falseValue, dim);
107+ } else {
108+ cstr.bound (value) >= trueValue;
109+ cstr.bound (value) <= falseValue;
110+ }
111+ }
112+ // If falseValue <= trueValue:
113+ // * result <= trueValue
114+ // * result >= falseValue
115+ if (cstr.compare (falseValue, dim,
116+ ValueBoundsConstraintSet::ComparisonOperator::LE,
117+ trueValue, dim)) {
118+ if (dim) {
119+ cstr.bound (value)[*dim] >= cstr.getExpr (falseValue, dim);
120+ cstr.bound (value)[*dim] <= cstr.getExpr (trueValue, dim);
121+ } else {
122+ cstr.bound (value) >= falseValue;
123+ cstr.bound (value) <= trueValue;
124+ }
125+ }
126+ }
127+
128+ void populateBoundsForIndexValue (Operation *op, Value value,
129+ ValueBoundsConstraintSet &cstr) const {
130+ populateBounds (cast<SelectOp>(op), /* dim=*/ std::nullopt , cstr);
131+ }
132+
133+ void populateBoundsForShapedValueDim (Operation *op, Value value, int64_t dim,
134+ ValueBoundsConstraintSet &cstr) const {
135+ populateBounds (cast<SelectOp>(op), dim, cstr);
136+ }
137+ };
69138} // namespace
70139} // namespace arith
71140} // namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
77146 arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
78147 arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
79148 arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
149+ arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
80150 });
81151}
0 commit comments