@@ -75,6 +75,75 @@ struct MulIOpInterface
75
75
}
76
76
};
77
77
78
+ struct SelectOpInterface
79
+ : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
80
+ SelectOp> {
81
+
82
+ static void populateBounds (SelectOp selectOp, std::optional<int64_t > dim,
83
+ ValueBoundsConstraintSet &cstr) {
84
+ Value value = selectOp.getResult ();
85
+ Value condition = selectOp.getCondition ();
86
+ Value trueValue = selectOp.getTrueValue ();
87
+ Value falseValue = selectOp.getFalseValue ();
88
+
89
+ if (isa<ShapedType>(condition.getType ())) {
90
+ // If the condition is a shaped type, the condition is applied
91
+ // element-wise. All three operands must have the same shape.
92
+ cstr.bound (value)[*dim] == cstr.getExpr (trueValue, dim);
93
+ cstr.bound (value)[*dim] == cstr.getExpr (falseValue, dim);
94
+ cstr.bound (value)[*dim] == cstr.getExpr (condition, dim);
95
+ return ;
96
+ }
97
+
98
+ // Populate constraints for the true/false values (and all values on the
99
+ // backward slice, as long as the current stop condition is not satisfied).
100
+ cstr.populateConstraints (trueValue, dim);
101
+ cstr.populateConstraints (falseValue, dim);
102
+ auto boundsBuilder = cstr.bound (value);
103
+ if (dim)
104
+ boundsBuilder[*dim];
105
+
106
+ // Compare yielded values.
107
+ // If trueValue <= falseValue:
108
+ // * result <= falseValue
109
+ // * result >= trueValue
110
+ if (cstr.compare (trueValue, dim,
111
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
112
+ falseValue, dim)) {
113
+ if (dim) {
114
+ cstr.bound (value)[*dim] >= cstr.getExpr (trueValue, dim);
115
+ cstr.bound (value)[*dim] <= cstr.getExpr (falseValue, dim);
116
+ } else {
117
+ cstr.bound (value) >= trueValue;
118
+ cstr.bound (value) <= falseValue;
119
+ }
120
+ }
121
+ // If falseValue <= trueValue:
122
+ // * result <= trueValue
123
+ // * result >= falseValue
124
+ if (cstr.compare (falseValue, dim,
125
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
126
+ trueValue, dim)) {
127
+ if (dim) {
128
+ cstr.bound (value)[*dim] >= cstr.getExpr (falseValue, dim);
129
+ cstr.bound (value)[*dim] <= cstr.getExpr (trueValue, dim);
130
+ } else {
131
+ cstr.bound (value) >= falseValue;
132
+ cstr.bound (value) <= trueValue;
133
+ }
134
+ }
135
+ }
136
+
137
+ void populateBoundsForIndexValue (Operation *op, Value value,
138
+ ValueBoundsConstraintSet &cstr) const {
139
+ populateBounds (cast<SelectOp>(op), /* dim=*/ std::nullopt, cstr);
140
+ }
141
+
142
+ void populateBoundsForShapedValueDim (Operation *op, Value value, int64_t dim,
143
+ ValueBoundsConstraintSet &cstr) const {
144
+ populateBounds (cast<SelectOp>(op), dim, cstr);
145
+ }
146
+ };
78
147
} // namespace
79
148
} // namespace arith
80
149
} // namespace mlir
@@ -86,5 +155,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
86
155
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
87
156
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
88
157
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
158
+ arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
89
159
});
90
160
}
0 commit comments