@@ -66,6 +66,75 @@ struct MulIOpInterface
6666  }
6767};
6868
69+ struct  SelectOpInterface 
70+     : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
71+                                                    SelectOp> {
72+ 
73+   static  void  populateBounds (SelectOp selectOp, std::optional<int64_t > dim,
74+                              ValueBoundsConstraintSet &cstr) {
75+     Value value = selectOp.getResult ();
76+     Value condition = selectOp.getCondition ();
77+     Value trueValue = selectOp.getTrueValue ();
78+     Value falseValue = selectOp.getFalseValue ();
79+ 
80+     if  (isa<ShapedType>(condition.getType ())) {
81+       //  If the condition is a shaped type, the condition is applied
82+       //  element-wise. All three operands must have the same shape.
83+       cstr.bound (value)[*dim] == cstr.getExpr (trueValue, dim);
84+       cstr.bound (value)[*dim] == cstr.getExpr (falseValue, dim);
85+       cstr.bound (value)[*dim] == cstr.getExpr (condition, dim);
86+       return ;
87+     }
88+ 
89+     //  Populate constraints for the true/false values (and all values on the
90+     //  backward slice, as long as the current stop condition is not satisfied).
91+     cstr.populateConstraints (trueValue, dim);
92+     cstr.populateConstraints (falseValue, dim);
93+     auto  boundsBuilder = cstr.bound (value);
94+     if  (dim)
95+       boundsBuilder[*dim];
96+ 
97+     //  Compare yielded values.
98+     //  If trueValue <= falseValue:
99+     //  * result <= falseValue
100+     //  * result >= trueValue
101+     if  (cstr.compare (trueValue, dim,
102+                      ValueBoundsConstraintSet::ComparisonOperator::LE,
103+                      falseValue, dim)) {
104+       if  (dim) {
105+         cstr.bound (value)[*dim] >= cstr.getExpr (trueValue, dim);
106+         cstr.bound (value)[*dim] <= cstr.getExpr (falseValue, dim);
107+       } else  {
108+         cstr.bound (value) >= trueValue;
109+         cstr.bound (value) <= falseValue;
110+       }
111+     }
112+     //  If falseValue <= trueValue:
113+     //  * result <= trueValue
114+     //  * result >= falseValue
115+     if  (cstr.compare (falseValue, dim,
116+                      ValueBoundsConstraintSet::ComparisonOperator::LE,
117+                      trueValue, dim)) {
118+       if  (dim) {
119+         cstr.bound (value)[*dim] >= cstr.getExpr (falseValue, dim);
120+         cstr.bound (value)[*dim] <= cstr.getExpr (trueValue, dim);
121+       } else  {
122+         cstr.bound (value) >= falseValue;
123+         cstr.bound (value) <= trueValue;
124+       }
125+     }
126+   }
127+ 
128+   void  populateBoundsForIndexValue (Operation *op, Value value,
129+                                    ValueBoundsConstraintSet &cstr) const  {
130+     populateBounds (cast<SelectOp>(op), /* dim=*/ nullopt , cstr);
131+   }
132+ 
133+   void  populateBoundsForShapedValueDim (Operation *op, Value value, int64_t  dim,
134+                                        ValueBoundsConstraintSet &cstr) const  {
135+     populateBounds (cast<SelectOp>(op), dim, cstr);
136+   }
137+ };
69138} //  namespace
70139} //  namespace arith
71140} //  namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
77146    arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
78147    arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
79148    arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
149+     arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
80150  });
81151}
0 commit comments