@@ -66,6 +66,75 @@ struct MulIOpInterface
66
66
}
67
67
};
68
68
69
+ struct SelectOpInterface
70
+ : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
71
+ SelectOp> {
72
+
73
+ static void populateBounds (SelectOp selectOp, std::optional<int64_t > dim,
74
+ ValueBoundsConstraintSet &cstr) {
75
+ Value value = selectOp.getResult ();
76
+ Value condition = selectOp.getCondition ();
77
+ Value trueValue = selectOp.getTrueValue ();
78
+ Value falseValue = selectOp.getFalseValue ();
79
+
80
+ if (isa<ShapedType>(condition.getType ())) {
81
+ // If the condition is a shaped type, the condition is applied
82
+ // element-wise. All three operands must have the same shape.
83
+ cstr.bound (value)[*dim] == cstr.getExpr (trueValue, dim);
84
+ cstr.bound (value)[*dim] == cstr.getExpr (falseValue, dim);
85
+ cstr.bound (value)[*dim] == cstr.getExpr (condition, dim);
86
+ return ;
87
+ }
88
+
89
+ // Populate constraints for the true/false values (and all values on the
90
+ // backward slice, as long as the current stop condition is not satisfied).
91
+ cstr.populateConstraints (trueValue, dim);
92
+ cstr.populateConstraints (falseValue, dim);
93
+ auto boundsBuilder = cstr.bound (value);
94
+ if (dim)
95
+ boundsBuilder[*dim];
96
+
97
+ // Compare yielded values.
98
+ // If trueValue <= falseValue:
99
+ // * result <= falseValue
100
+ // * result >= trueValue
101
+ if (cstr.compare (trueValue, dim,
102
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
103
+ falseValue, dim)) {
104
+ if (dim) {
105
+ cstr.bound (value)[*dim] >= cstr.getExpr (trueValue, dim);
106
+ cstr.bound (value)[*dim] <= cstr.getExpr (falseValue, dim);
107
+ } else {
108
+ cstr.bound (value) >= trueValue;
109
+ cstr.bound (value) <= falseValue;
110
+ }
111
+ }
112
+ // If falseValue <= trueValue:
113
+ // * result <= trueValue
114
+ // * result >= falseValue
115
+ if (cstr.compare (falseValue, dim,
116
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
117
+ trueValue, dim)) {
118
+ if (dim) {
119
+ cstr.bound (value)[*dim] >= cstr.getExpr (falseValue, dim);
120
+ cstr.bound (value)[*dim] <= cstr.getExpr (trueValue, dim);
121
+ } else {
122
+ cstr.bound (value) >= falseValue;
123
+ cstr.bound (value) <= trueValue;
124
+ }
125
+ }
126
+ }
127
+
128
+ void populateBoundsForIndexValue (Operation *op, Value value,
129
+ ValueBoundsConstraintSet &cstr) const {
130
+ populateBounds (cast<SelectOp>(op), /* dim=*/ std::nullopt, cstr);
131
+ }
132
+
133
+ void populateBoundsForShapedValueDim (Operation *op, Value value, int64_t dim,
134
+ ValueBoundsConstraintSet &cstr) const {
135
+ populateBounds (cast<SelectOp>(op), dim, cstr);
136
+ }
137
+ };
69
138
} // namespace
70
139
} // namespace arith
71
140
} // namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
77
146
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
78
147
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
79
148
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
149
+ arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
80
150
});
81
151
}
0 commit comments