@@ -110,25 +110,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
110
110
assertValidValueDim (value, dim);
111
111
#endif // NDEBUG
112
112
113
+ // Check if the value/dim is statically known. In that case, an affine
114
+ // constant expression should be returned. This allows us to support
115
+ // multiplications with constants. (Multiplications of two columns in the
116
+ // constraint set is not supported.)
117
+ std::optional<int64_t > constSize = std::nullopt;
113
118
auto shapedType = dyn_cast<ShapedType>(value.getType ());
114
119
if (shapedType) {
115
- // Static dimension: return constant directly.
116
120
if (shapedType.hasRank () && !shapedType.isDynamicDim (*dim))
117
- return builder.getAffineConstantExpr (shapedType.getDimSize (*dim));
118
- } else {
119
- // Constant index value: return directly.
120
- if (auto constInt = ::getConstantIntValue (value))
121
- return builder.getAffineConstantExpr (*constInt);
121
+ constSize = shapedType.getDimSize (*dim);
122
+ } else if (auto constInt = ::getConstantIntValue (value)) {
123
+ constSize = *constInt;
122
124
}
123
125
124
- // Dynamic value: add to constraint set.
126
+ // If the value/dim is already mapped, return the corresponding expression
127
+ // directly.
125
128
ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
126
- if (!valueDimToPosition.contains (valueDim))
127
- (void )insert (value, dim);
128
- int64_t pos = getPos (value, dim);
129
- return pos < cstr.getNumDimVars ()
130
- ? builder.getAffineDimExpr (pos)
131
- : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
129
+ if (valueDimToPosition.contains (valueDim)) {
130
+ // If it is a constant, return an affine constant expression. Otherwise,
131
+ // return an affine expression that represents the respective column in the
132
+ // constraint set.
133
+ if (constSize)
134
+ return builder.getAffineConstantExpr (*constSize);
135
+ return getPosExpr (getPos (value, dim));
136
+ }
137
+
138
+ if (constSize) {
139
+ // Constant index value/dim: add column to the constraint set, add EQ bound
140
+ // and return an affine constant expression without pushing the newly added
141
+ // column to the worklist.
142
+ (void )insert (value, dim, /* isSymbol=*/ true , /* addToWorklist=*/ false );
143
+ if (shapedType)
144
+ bound (value)[*dim] == *constSize;
145
+ else
146
+ bound (value) == *constSize;
147
+ return builder.getAffineConstantExpr (*constSize);
148
+ }
149
+
150
+ // Dynamic value/dim: insert column to the constraint set and put it on the
151
+ // worklist. Return an affine expression that represents the newly inserted
152
+ // column in the constraint set.
153
+ return getPosExpr (insert (value, dim, /* isSymbol=*/ true ));
132
154
}
133
155
134
156
AffineExpr ValueBoundsConstraintSet::getExpr (OpFoldResult ofr) {
@@ -145,7 +167,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
145
167
146
168
int64_t ValueBoundsConstraintSet::insert (Value value,
147
169
std::optional<int64_t > dim,
148
- bool isSymbol) {
170
+ bool isSymbol, bool addToWorklist ) {
149
171
#ifndef NDEBUG
150
172
assertValidValueDim (value, dim);
151
173
#endif // NDEBUG
@@ -160,7 +182,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
160
182
if (positionToValueDim[i].has_value ())
161
183
valueDimToPosition[*positionToValueDim[i]] = i;
162
184
163
- worklist.push (pos);
185
+ if (addToWorklist) {
186
+ LLVM_DEBUG (llvm::dbgs () << " Push to worklist: " << value
187
+ << " (dim: " << dim.value_or (kIndexValue ) << " )\n " );
188
+ worklist.push (pos);
189
+ }
190
+
164
191
return pos;
165
192
}
166
193
@@ -190,6 +217,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
190
217
return it->second ;
191
218
}
192
219
220
+ AffineExpr ValueBoundsConstraintSet::getPosExpr (int64_t pos) {
221
+ assert (pos >= 0 && pos < cstr.getNumDimAndSymbolVars () && " invalid position" );
222
+ return pos < cstr.getNumDimVars ()
223
+ ? builder.getAffineDimExpr (pos)
224
+ : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
225
+ }
226
+
193
227
static Operation *getOwnerOfValue (Value value) {
194
228
if (auto bbArg = dyn_cast<BlockArgument>(value))
195
229
return bbArg.getOwner ()->getParentOp ();
@@ -492,15 +526,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
492
526
493
527
// Default stop condition if none was specified: Keep adding constraints until
494
528
// a bound could be computed.
495
- int64_t pos;
529
+ int64_t pos = 0 ;
496
530
auto defaultStopCondition = [&](Value v, std::optional<int64_t > dim,
497
531
ValueBoundsConstraintSet &cstr) {
498
532
return cstr.cstr .getConstantBound64 (type, pos).has_value ();
499
533
};
500
534
501
535
ValueBoundsConstraintSet cstr (
502
536
map.getContext (), stopCondition ? stopCondition : defaultStopCondition);
503
- cstr.populateConstraintsSet (map, operands, &pos);
537
+ pos = cstr.populateConstraints (map, operands);
538
+ assert (pos == 0 && " expected `map` is the first column" );
504
539
505
540
// Compute constant bound for `valueDim`.
506
541
int64_t ubAdjustment = closedUB ? 0 : 1 ;
@@ -509,29 +544,28 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
509
544
return failure ();
510
545
}
511
546
512
- int64_t
513
- ValueBoundsConstraintSet::populateConstraintsSet (Value value,
514
- std::optional<int64_t > dim) {
547
+ void ValueBoundsConstraintSet::populateConstraints (Value value,
548
+ std::optional<int64_t > dim) {
515
549
#ifndef NDEBUG
516
550
assertValidValueDim (value, dim);
517
551
#endif // NDEBUG
518
552
519
- AffineMap map =
520
- AffineMap::get (/* dimCount=*/ 1 , /* symbolCount=*/ 0 ,
521
- Builder (value.getContext ()).getAffineDimExpr (0 ));
522
- return populateConstraintsSet (map, {{value, dim}});
553
+ // `getExpr` pushes the value/dim onto the worklist (unless it was already
554
+ // analyzed).
555
+ (void )getExpr (value, dim);
556
+ // Process all values/dims on the worklist. This may traverse and analyze
557
+ // additional IR, depending the current stop function.
558
+ processWorklist ();
523
559
}
524
560
525
- int64_t ValueBoundsConstraintSet::populateConstraintsSet (AffineMap map,
526
- ValueDimList operands,
527
- int64_t *posOut) {
561
+ int64_t ValueBoundsConstraintSet::populateConstraints (AffineMap map,
562
+ ValueDimList operands) {
528
563
assert (map.getNumResults () == 1 && " expected affine map with one result" );
529
564
int64_t pos = insert (/* isSymbol=*/ false );
530
- if (posOut)
531
- *posOut = pos;
532
565
533
566
// Add map and operands to the constraint set. Dimensions are converted to
534
- // symbols. All operands are added to the worklist.
567
+ // symbols. All operands are added to the worklist (unless they were already
568
+ // processed).
535
569
auto mapper = [&](std::pair<Value, std::optional<int64_t >> v) {
536
570
return getExpr (v.first , v.second );
537
571
};
@@ -566,6 +600,55 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
566
600
{{value1, dim1}, {value2, dim2}});
567
601
}
568
602
603
+ bool ValueBoundsConstraintSet::compare (Value lhs, std::optional<int64_t > lhsDim,
604
+ ComparisonOperator cmp, Value rhs,
605
+ std::optional<int64_t > rhsDim) {
606
+ // This function returns "true" if "lhs CMP rhs" is proven to hold.
607
+ //
608
+ // Example for ComparisonOperator::LE and index-typed values: We would like to
609
+ // prove that lhs <= rhs. Proof by contradiction: add the inverse
610
+ // relation (lhs > rhs) to the constraint set and check if the resulting
611
+ // constraint set is "empty" (i.e. has no solution). In that case,
612
+ // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
613
+
614
+ // We cannot prove anything if the constraint set is already empty.
615
+ if (cstr.isEmpty ()) {
616
+ LLVM_DEBUG (
617
+ llvm::dbgs ()
618
+ << " cannot compare value/dims: constraint system is already empty" );
619
+ return false ;
620
+ }
621
+
622
+ // EQ can be expressed as LE and GE.
623
+ if (cmp == EQ)
624
+ return compare (lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
625
+ compare (lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
626
+
627
+ // Construct inequality. For the above example: lhs > rhs.
628
+ // `IntegerRelation` inequalities are expressed in the "flattened" form and
629
+ // with ">= 0". I.e., lhs - rhs - 1 >= 0.
630
+ SmallVector<int64_t > eq (cstr.getNumDimAndSymbolVars () + 1 , 0 );
631
+ if (cmp == LT || cmp == LE) {
632
+ ++eq[getPos (lhs, lhsDim)];
633
+ --eq[getPos (rhs, rhsDim)];
634
+ } else if (cmp == GT || cmp == GE) {
635
+ --eq[getPos (lhs, lhsDim)];
636
+ ++eq[getPos (rhs, rhsDim)];
637
+ } else {
638
+ llvm_unreachable (" unsupported comparison operator" );
639
+ }
640
+ if (cmp == LE || cmp == GE)
641
+ eq[cstr.getNumDimAndSymbolVars ()] -= 1 ;
642
+
643
+ // Add inequality to the constraint set and check if it made the constraint
644
+ // set empty.
645
+ int64_t ineqPos = cstr.getNumInequalities ();
646
+ cstr.addInequality (eq);
647
+ bool isEmpty = cstr.isEmpty ();
648
+ cstr.removeInequality (ineqPos);
649
+ return isEmpty;
650
+ }
651
+
569
652
FailureOr<bool >
570
653
ValueBoundsConstraintSet::areEqual (Value value1, Value value2,
571
654
std::optional<int64_t > dim1,
0 commit comments