@@ -105,25 +105,41 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
105
105
assertValidValueDim (value, dim);
106
106
#endif // NDEBUG
107
107
108
+ auto getPosExpr = [&](int64_t pos) {
109
+ assert (pos >= 0 && pos < cstr.getNumDimAndSymbolVars () &&
110
+ " invalid position" );
111
+ return pos < cstr.getNumDimVars ()
112
+ ? builder.getAffineDimExpr (pos)
113
+ : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
114
+ };
115
+
116
+ // If the value/dim is already mapped, return the corresponding expression
117
+ // directly.
118
+ ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
119
+ if (valueDimToPosition.contains (valueDim))
120
+ return getPosExpr (getPos (value, dim));
121
+
108
122
auto shapedType = dyn_cast<ShapedType>(value.getType ());
109
123
if (shapedType) {
110
- // Static dimension: return constant directly.
111
- if (shapedType.hasRank () && !shapedType.isDynamicDim (*dim))
124
+ // Static dimension: add EQ bound and return expression without pushing the
125
+ // dim onto the worklist.
126
+ if (shapedType.hasRank () && !shapedType.isDynamicDim (*dim)) {
127
+ (void )insert (value, dim, /* isSymbol=*/ true , /* addToWorklist=*/ false );
128
+ bound (value)[*dim] == shapedType.getDimSize (*dim);
112
129
return builder.getAffineConstantExpr (shapedType.getDimSize (*dim));
130
+ }
113
131
} else {
114
- // Constant index value: return directly.
115
- if (auto constInt = ::getConstantIntValue (value))
132
+ // Constant index value: add EQ bound and return expression without pushing
133
+ // the value onto the worklist.
134
+ if (auto constInt = ::getConstantIntValue (value)) {
135
+ (void )insert (value, dim, /* isSymbol=*/ true , /* addToWorklist=*/ false );
136
+ bound (value) == *constInt;
116
137
return builder.getAffineConstantExpr (*constInt);
138
+ }
117
139
}
118
140
119
- // Dynamic value: add to constraint set.
120
- ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
121
- if (!valueDimToPosition.contains (valueDim))
122
- (void )insert (value, dim);
123
- int64_t pos = getPos (value, dim);
124
- return pos < cstr.getNumDimVars ()
125
- ? builder.getAffineDimExpr (pos)
126
- : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
141
+ // Dynamic value/dim: add to worklist.
142
+ return getPosExpr (insert (value, dim, /* isSymbol=*/ true ));
127
143
}
128
144
129
145
AffineExpr ValueBoundsConstraintSet::getExpr (OpFoldResult ofr) {
@@ -140,7 +156,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
140
156
141
157
int64_t ValueBoundsConstraintSet::insert (Value value,
142
158
std::optional<int64_t > dim,
143
- bool isSymbol) {
159
+ bool isSymbol, bool addToWorklist ) {
144
160
#ifndef NDEBUG
145
161
assertValidValueDim (value, dim);
146
162
#endif // NDEBUG
@@ -155,7 +171,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
155
171
if (positionToValueDim[i].has_value ())
156
172
valueDimToPosition[*positionToValueDim[i]] = i;
157
173
158
- worklist.push (pos);
174
+ if (addToWorklist) {
175
+ LLVM_DEBUG (llvm::dbgs () << " Push to worklist: " << value
176
+ << " (dim: " << dim.value_or (kIndexValue ) << " )\n " );
177
+ worklist.push (pos);
178
+ }
179
+
159
180
return pos;
160
181
}
161
182
@@ -191,7 +212,8 @@ static Operation *getOwnerOfValue(Value value) {
191
212
return value.getDefiningOp ();
192
213
}
193
214
194
- void ValueBoundsConstraintSet::processWorklist (StopConditionFn stopCondition) {
215
+ void ValueBoundsConstraintSet::processWorklist () {
216
+ LLVM_DEBUG (llvm::dbgs () << " Processing value bounds worklist...\n " );
195
217
while (!worklist.empty ()) {
196
218
int64_t pos = worklist.front ();
197
219
worklist.pop ();
@@ -212,20 +234,29 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
212
234
213
235
// Do not process any further if the stop condition is met.
214
236
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional (dim);
215
- if (stopCondition (value, maybeDim))
237
+ if (currentStopCondition (value, maybeDim)) {
238
+ LLVM_DEBUG (llvm::dbgs () << " Stop condition met for: " << value
239
+ << " (dim: " << maybeDim << " )\n " );
216
240
continue ;
241
+ }
217
242
218
243
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
219
244
// the worklist.
220
245
auto valueBoundsOp =
221
246
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue (value));
247
+ LLVM_DEBUG (llvm::dbgs ()
248
+ << " Query value bounds for: " << value
249
+ << " (owner: " << getOwnerOfValue (value)->getName () << " )\n " );
222
250
if (valueBoundsOp) {
223
251
if (dim == kIndexValue ) {
224
252
valueBoundsOp.populateBoundsForIndexValue (value, *this );
225
253
} else {
226
254
valueBoundsOp.populateBoundsForShapedValueDim (value, dim, *this );
227
255
}
228
256
continue ;
257
+ } else {
258
+ LLVM_DEBUG (llvm::dbgs ()
259
+ << " --> ValueBoundsOpInterface not implemented\n " );
229
260
}
230
261
231
262
// If the op does not implement `ValueBoundsOpInterface`, check if it
@@ -301,7 +332,8 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
301
332
ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
302
333
ValueBoundsConstraintSet cstr (value.getContext ());
303
334
int64_t pos = cstr.insert (value, dim, /* isSymbol=*/ false );
304
- cstr.processWorklist (stopCondition);
335
+ cstr.currentStopCondition = stopCondition;
336
+ cstr.processWorklist ();
305
337
306
338
// Project out all variables (apart from `valueDim`) that do not match the
307
339
// stop condition.
@@ -494,14 +526,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
494
526
// Process the backward slice of `operands` (i.e., reverse use-def chain)
495
527
// until `stopCondition` is met.
496
528
if (stopCondition) {
497
- cstr.processWorklist (stopCondition);
529
+ cstr.currentStopCondition = stopCondition;
530
+ cstr.processWorklist ();
498
531
} else {
499
532
// No stop condition specified: Keep adding constraints until a bound could
500
533
// be computed.
501
- cstr.processWorklist (
502
- /* stopCondition=*/ [&](Value v, std::optional<int64_t > dim) {
503
- return cstr.cstr .getConstantBound64 (type, pos).has_value ();
504
- });
534
+ auto stopCondFn = [&](Value v, std::optional<int64_t > dim) {
535
+ return cstr.cstr .getConstantBound64 (type, pos).has_value ();
536
+ };
537
+ cstr.currentStopCondition = stopCondFn;
538
+ cstr.processWorklist ();
505
539
}
506
540
507
541
// Compute constant bound for `valueDim`.
@@ -538,6 +572,68 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
538
572
{{value1, dim1}, {value2, dim2}});
539
573
}
540
574
575
+ void ValueBoundsConstraintSet::populateConstraints (Value value,
576
+ std::optional<int64_t > dim) {
577
+ // `getExpr` pushes the value/dim onto the worklist (unless it was already
578
+ // analyzed).
579
+ (void )getExpr (value, dim);
580
+ // Process all values/dims on the worklist. This may traverse and analyze
581
+ // additional IR, depending the current stop function.
582
+ processWorklist ();
583
+ }
584
+
585
+ bool ValueBoundsConstraintSet::compare (Value value1,
586
+ std::optional<int64_t > dim1,
587
+ ComparisonOperator cmp, Value value2,
588
+ std::optional<int64_t > dim2) {
589
+ // This function returns "true" if value1/dim1 CMP value2/dim2 is proved to
590
+ // hold.
591
+ //
592
+ // Example for ComparisonOperator::LE and index-typed values: We would like to
593
+ // prove that value1 <= value2. Proof by contradiction: add the inverse
594
+ // relation (value1 > value2) to the constraint set and check if the resulting
595
+ // constraint set is "empty" (i.e. has no solution). In that case,
596
+ // value1 > value2 must be incorrect and we can deduce that value1 <= value2
597
+ // holds.
598
+
599
+ // We cannot use prove anything if the constraint set is already empty.
600
+ if (cstr.isEmpty ()) {
601
+ LLVM_DEBUG (
602
+ llvm::dbgs ()
603
+ << " cannot compare value/dims: constraint system is already empty" );
604
+ return false ;
605
+ }
606
+
607
+ // EQ can be expressed as LE and GE.
608
+ if (cmp == EQ)
609
+ return compare (value1, dim1, ComparisonOperator::LE, value2, dim2) &&
610
+ compare (value1, dim1, ComparisonOperator::GE, value2, dim2);
611
+
612
+ // Construct inequality. For the above example: value1 > value2.
613
+ // `IntegerRelation` inequalities are expressed in the "flattened" form and
614
+ // with ">= 0". I.e., value1 - value2 - 1 >= 0.
615
+ SmallVector<int64_t > eq (cstr.getNumDimAndSymbolVars () + 1 , 0 );
616
+ if (cmp == LT || cmp == LE) {
617
+ eq[getPos (value1, dim1)]++;
618
+ eq[getPos (value2, dim2)]--;
619
+ } else if (cmp == GT || cmp == GE) {
620
+ eq[getPos (value1, dim1)]--;
621
+ eq[getPos (value2, dim2)]++;
622
+ } else {
623
+ llvm_unreachable (" unsupported comparison operator" );
624
+ }
625
+ if (cmp == LE || cmp == GE)
626
+ eq[cstr.getNumDimAndSymbolVars ()] -= 1 ;
627
+
628
+ // Add inequality to the constraint set and check if it made the constraint
629
+ // set empty.
630
+ int64_t ineqPos = cstr.getNumInequalities ();
631
+ cstr.addInequality (eq);
632
+ bool isEmpty = cstr.isEmpty ();
633
+ cstr.removeInequality (ineqPos);
634
+ return isEmpty;
635
+ }
636
+
541
637
FailureOr<bool >
542
638
ValueBoundsConstraintSet::areEqual (Value value1, Value value2,
543
639
std::optional<int64_t > dim1,
0 commit comments