@@ -105,25 +105,41 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
105105 assertValidValueDim (value, dim);
106106#endif // NDEBUG
107107
108+ auto getPosExpr = [&](int64_t pos) {
109+ assert (pos >= 0 && pos < cstr.getNumDimAndSymbolVars () &&
110+ " invalid position" );
111+ return pos < cstr.getNumDimVars ()
112+ ? builder.getAffineDimExpr (pos)
113+ : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
114+ };
115+
116+ // If the value/dim is already mapped, return the corresponding expression
117+ // directly.
118+ ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
119+ if (valueDimToPosition.contains (valueDim))
120+ return getPosExpr (getPos (value, dim));
121+
108122 auto shapedType = dyn_cast<ShapedType>(value.getType ());
109123 if (shapedType) {
110- // Static dimension: return constant directly.
111- if (shapedType.hasRank () && !shapedType.isDynamicDim (*dim))
124+ // Static dimension: add EQ bound and return expression without pushing the
125+ // dim onto the worklist.
126+ if (shapedType.hasRank () && !shapedType.isDynamicDim (*dim)) {
127+ (void )insert (value, dim, /* isSymbol=*/ true , /* addToWorklist=*/ false );
128+ bound (value)[*dim] == shapedType.getDimSize (*dim);
112129 return builder.getAffineConstantExpr (shapedType.getDimSize (*dim));
130+ }
113131 } else {
114- // Constant index value: return directly.
115- if (auto constInt = ::getConstantIntValue (value))
132+ // Constant index value: add EQ bound and return expression without pushing
133+ // the value onto the worklist.
134+ if (auto constInt = ::getConstantIntValue (value)) {
135+ (void )insert (value, dim, /* isSymbol=*/ true , /* addToWorklist=*/ false );
136+ bound (value) == *constInt;
116137 return builder.getAffineConstantExpr (*constInt);
138+ }
117139 }
118140
119- // Dynamic value: add to constraint set.
120- ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
121- if (!valueDimToPosition.contains (valueDim))
122- (void )insert (value, dim);
123- int64_t pos = getPos (value, dim);
124- return pos < cstr.getNumDimVars ()
125- ? builder.getAffineDimExpr (pos)
126- : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
141+ // Dynamic value/dim: add to worklist.
142+ return getPosExpr (insert (value, dim, /* isSymbol=*/ true ));
127143}
128144
129145AffineExpr ValueBoundsConstraintSet::getExpr (OpFoldResult ofr) {
@@ -140,7 +156,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
140156
141157int64_t ValueBoundsConstraintSet::insert (Value value,
142158 std::optional<int64_t > dim,
143- bool isSymbol) {
159+ bool isSymbol, bool addToWorklist ) {
144160#ifndef NDEBUG
145161 assertValidValueDim (value, dim);
146162#endif // NDEBUG
@@ -155,7 +171,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
155171 if (positionToValueDim[i].has_value ())
156172 valueDimToPosition[*positionToValueDim[i]] = i;
157173
158- worklist.push (pos);
174+ if (addToWorklist) {
175+ LLVM_DEBUG (llvm::dbgs () << " Push to worklist: " << value
176+ << " (dim: " << dim.value_or (kIndexValue ) << " )\n " );
177+ worklist.push (pos);
178+ }
179+
159180 return pos;
160181}
161182
@@ -191,7 +212,8 @@ static Operation *getOwnerOfValue(Value value) {
191212 return value.getDefiningOp ();
192213}
193214
194- void ValueBoundsConstraintSet::processWorklist (StopConditionFn stopCondition) {
215+ void ValueBoundsConstraintSet::processWorklist () {
216+ LLVM_DEBUG (llvm::dbgs () << " Processing value bounds worklist...\n " );
195217 while (!worklist.empty ()) {
196218 int64_t pos = worklist.front ();
197219 worklist.pop ();
@@ -212,20 +234,29 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
212234
213235 // Do not process any further if the stop condition is met.
214236 auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional (dim);
215- if (stopCondition (value, maybeDim))
237+ if (currentStopCondition (value, maybeDim)) {
238+ LLVM_DEBUG (llvm::dbgs () << " Stop condition met for: " << value
239+ << " (dim: " << maybeDim << " )\n " );
216240 continue ;
241+ }
217242
218243 // Query `ValueBoundsOpInterface` for constraints. New items may be added to
219244 // the worklist.
220245 auto valueBoundsOp =
221246 dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue (value));
247+ LLVM_DEBUG (llvm::dbgs ()
248+ << " Query value bounds for: " << value
249+ << " (owner: " << getOwnerOfValue (value)->getName () << " )\n " );
222250 if (valueBoundsOp) {
223251 if (dim == kIndexValue ) {
224252 valueBoundsOp.populateBoundsForIndexValue (value, *this );
225253 } else {
226254 valueBoundsOp.populateBoundsForShapedValueDim (value, dim, *this );
227255 }
228256 continue ;
257+ } else {
258+ LLVM_DEBUG (llvm::dbgs ()
259+ << " --> ValueBoundsOpInterface not implemented\n " );
229260 }
230261
231262 // If the op does not implement `ValueBoundsOpInterface`, check if it
@@ -301,7 +332,8 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
301332 ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
302333 ValueBoundsConstraintSet cstr (value.getContext ());
303334 int64_t pos = cstr.insert (value, dim, /* isSymbol=*/ false );
304- cstr.processWorklist (stopCondition);
335+ cstr.currentStopCondition = stopCondition;
336+ cstr.processWorklist ();
305337
306338 // Project out all variables (apart from `valueDim`) that do not match the
307339 // stop condition.
@@ -494,14 +526,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
494526 // Process the backward slice of `operands` (i.e., reverse use-def chain)
495527 // until `stopCondition` is met.
496528 if (stopCondition) {
497- cstr.processWorklist (stopCondition);
529+ cstr.currentStopCondition = stopCondition;
530+ cstr.processWorklist ();
498531 } else {
499532 // No stop condition specified: Keep adding constraints until a bound could
500533 // be computed.
501- cstr.processWorklist (
502- /* stopCondition=*/ [&](Value v, std::optional<int64_t > dim) {
503- return cstr.cstr .getConstantBound64 (type, pos).has_value ();
504- });
534+ auto stopCondFn = [&](Value v, std::optional<int64_t > dim) {
535+ return cstr.cstr .getConstantBound64 (type, pos).has_value ();
536+ };
537+ cstr.currentStopCondition = stopCondFn;
538+ cstr.processWorklist ();
505539 }
506540
507541 // Compute constant bound for `valueDim`.
@@ -538,6 +572,68 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
538572 {{value1, dim1}, {value2, dim2}});
539573}
540574
575+ void ValueBoundsConstraintSet::populateConstraints (Value value,
576+ std::optional<int64_t > dim) {
577+ // `getExpr` pushes the value/dim onto the worklist (unless it was already
578+ // analyzed).
579+ (void )getExpr (value, dim);
580+ // Process all values/dims on the worklist. This may traverse and analyze
581+ // additional IR, depending the current stop function.
582+ processWorklist ();
583+ }
584+
585+ bool ValueBoundsConstraintSet::compare (Value value1,
586+ std::optional<int64_t > dim1,
587+ ComparisonOperator cmp, Value value2,
588+ std::optional<int64_t > dim2) {
589+ // This function returns "true" if value1/dim1 CMP value2/dim2 is proved to
590+ // hold.
591+ //
592+ // Example for ComparisonOperator::LE and index-typed values: We would like to
593+ // prove that value1 <= value2. Proof by contradiction: add the inverse
594+ // relation (value1 > value2) to the constraint set and check if the resulting
595+ // constraint set is "empty" (i.e. has no solution). In that case,
596+ // value1 > value2 must be incorrect and we can deduce that value1 <= value2
597+ // holds.
598+
599+ // We cannot use prove anything if the constraint set is already empty.
600+ if (cstr.isEmpty ()) {
601+ LLVM_DEBUG (
602+ llvm::dbgs ()
603+ << " cannot compare value/dims: constraint system is already empty" );
604+ return false ;
605+ }
606+
607+ // EQ can be expressed as LE and GE.
608+ if (cmp == EQ)
609+ return compare (value1, dim1, ComparisonOperator::LE, value2, dim2) &&
610+ compare (value1, dim1, ComparisonOperator::GE, value2, dim2);
611+
612+ // Construct inequality. For the above example: value1 > value2.
613+ // `IntegerRelation` inequalities are expressed in the "flattened" form and
614+ // with ">= 0". I.e., value1 - value2 - 1 >= 0.
615+ SmallVector<int64_t > eq (cstr.getNumDimAndSymbolVars () + 1 , 0 );
616+ if (cmp == LT || cmp == LE) {
617+ eq[getPos (value1, dim1)]++;
618+ eq[getPos (value2, dim2)]--;
619+ } else if (cmp == GT || cmp == GE) {
620+ eq[getPos (value1, dim1)]--;
621+ eq[getPos (value2, dim2)]++;
622+ } else {
623+ llvm_unreachable (" unsupported comparison operator" );
624+ }
625+ if (cmp == LE || cmp == GE)
626+ eq[cstr.getNumDimAndSymbolVars ()] -= 1 ;
627+
628+ // Add inequality to the constraint set and check if it made the constraint
629+ // set empty.
630+ int64_t ineqPos = cstr.getNumInequalities ();
631+ cstr.addInequality (eq);
632+ bool isEmpty = cstr.isEmpty ();
633+ cstr.removeInequality (ineqPos);
634+ return isEmpty;
635+ }
636+
541637FailureOr<bool >
542638ValueBoundsConstraintSet::areEqual (Value value1, Value value2,
543639 std::optional<int64_t > dim1,
0 commit comments