Skip to content

Commit 79734b8

Browse files
[mlir][Interfaces][WIP] ValueBoundsOpInterface: Variable
1 parent 21265f6 commit 79734b8

File tree

15 files changed

+247
-279
lines changed

15 files changed

+247
-279
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/Value.h"
1616
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1717
#include "llvm/ADT/SetVector.h"
18+
#include "llvm/ADT/SmallVector.h"
1819
#include "llvm/Support/ExtensibleRTTI.h"
1920

2021
#include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
111112
public:
112113
static char ID;
113114

115+
/// A variable that can be added to the constraint set as a "column". The
116+
/// value bounds infrastructure can compute bounds for variables and compare
117+
/// two variables.
118+
///
119+
/// Internally, a variable is represented as an affine map and operands.
120+
class Variable {
121+
public:
122+
/// Construct a variable for an index-typed attribute or SSA value.
123+
Variable(OpFoldResult ofr);
124+
125+
/// Construct a variable for an index-typed SSA value.
126+
Variable(Value indexValue);
127+
128+
/// Construct a variable for a dimension of a shaped value.
129+
Variable(Value shapedValue, int64_t dim);
130+
131+
/// Construct a variable for an index-typed attribute/SSA value or for a
132+
/// dimension of a shaped value. A non-null dimension must be provided if
133+
/// and only if `ofr` is a shaped value.
134+
Variable(OpFoldResult ofr, std::optional<int64_t> dim);
135+
136+
/// Construct a variable for a map and its operands.
137+
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
138+
Variable(AffineMap map, ArrayRef<Value> mapOperands);
139+
140+
MLIRContext *getContext() const { return map.getContext(); }
141+
142+
private:
143+
friend class ValueBoundsConstraintSet;
144+
AffineMap map;
145+
ValueDimList mapOperands;
146+
};
147+
114148
/// The stop condition when traversing the backward slice of a shaped value/
115149
/// index-type value. The traversal continues until the stop condition
116150
/// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
121155
using StopConditionFn = std::function<bool(
122156
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
123157

124-
/// Compute a bound for the given index-typed value or shape dimension size.
125-
/// The computed bound is stored in `resultMap`. The operands of the bound are
126-
/// stored in `mapOperands`. An operand is either an index-type SSA value
127-
/// or a shaped value and a dimension.
158+
/// Compute a bound for the given variable. The computed bound is stored in
159+
/// `resultMap`. The operands of the bound are stored in `mapOperands`. An
160+
/// operand is either an index-type SSA value or a shaped value and a
161+
/// dimension.
128162
///
129-
/// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
130-
/// is computed in terms of values/dimensions for which `stopCondition`
131-
/// evaluates to "true". To that end, the backward slice (reverse use-def
132-
/// chain) of the given value is visited in a worklist-driven manner and the
133-
/// constraint set is populated according to `ValueBoundsOpInterface` for each
134-
/// visited value.
163+
/// The bound is computed in terms of values/dimensions for which
164+
/// `stopCondition` evaluates to "true". To that end, the backward slice
165+
/// (reverse use-def chain) of the given value is visited in a worklist-driven
166+
/// manner and the constraint set is populated according to
167+
/// `ValueBoundsOpInterface` for each visited value.
135168
///
136169
/// By default, lower/equal bounds are closed and upper bounds are open. If
137170
/// `closedUB` is set to "true", upper bounds are also closed.
138-
static LogicalResult computeBound(AffineMap &resultMap,
139-
ValueDimList &mapOperands,
140-
presburger::BoundType type, Value value,
141-
std::optional<int64_t> dim,
142-
StopConditionFn stopCondition,
143-
bool closedUB = false);
171+
static LogicalResult
172+
computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
173+
presburger::BoundType type, const Variable &var,
174+
StopConditionFn stopCondition, bool closedUB = false);
144175

145176
/// Compute a bound in terms of the values/dimensions in `dependencies`. The
146177
/// computed bound consists of only constant terms and dependent values (or
147178
/// dimension sizes thereof).
148179
static LogicalResult
149180
computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
150-
presburger::BoundType type, Value value,
151-
std::optional<int64_t> dim, ValueDimList dependencies,
152-
bool closedUB = false);
181+
presburger::BoundType type, const Variable &var,
182+
ValueDimList dependencies, bool closedUB = false);
153183

154184
/// Compute a bound in that is independent of all values in `independencies`.
155185
///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
161191
/// appear in the computed bound.
162192
static LogicalResult
163193
computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
164-
presburger::BoundType type, Value value,
165-
std::optional<int64_t> dim, ValueRange independencies,
166-
bool closedUB = false);
194+
presburger::BoundType type, const Variable &var,
195+
ValueRange independencies, bool closedUB = false);
167196

168-
/// Compute a constant bound for the given affine map, where dims and symbols
169-
/// are bound to the given operands. The affine map must have exactly one
170-
/// result.
197+
/// Compute a constant bound for the given variable.
171198
///
172199
/// This function traverses the backward slice of the given operands in a
173200
/// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
182209
/// By default, lower/equal bounds are closed and upper bounds are open. If
183210
/// `closedUB` is set to "true", upper bounds are also closed.
184211
static FailureOr<int64_t>
185-
computeConstantBound(presburger::BoundType type, Value value,
186-
std::optional<int64_t> dim = std::nullopt,
212+
computeConstantBound(presburger::BoundType type, const Variable &var,
187213
StopConditionFn stopCondition = nullptr,
188214
bool closedUB = false);
189-
static FailureOr<int64_t> computeConstantBound(
190-
presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
191-
StopConditionFn stopCondition = nullptr, bool closedUB = false);
192-
static FailureOr<int64_t> computeConstantBound(
193-
presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
194-
StopConditionFn stopCondition = nullptr, bool closedUB = false);
195215

196216
/// Compute a constant delta between the given two values. Return "failure"
197217
/// if a constant delta could not be determined.
@@ -221,9 +241,7 @@ class ValueBoundsConstraintSet
221241
/// proven. This could be because the specified relation does in fact not hold
222242
/// or because there is not enough information in the constraint set. In other
223243
/// words, if we do not know for sure, this function returns "false".
224-
bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
225-
ComparisonOperator cmp, OpFoldResult rhs,
226-
std::optional<int64_t> rhsDim);
244+
bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs);
227245

228246
/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
229247
/// specified relation could not be proven. This could be because the
@@ -233,24 +251,11 @@ class ValueBoundsConstraintSet
233251
///
234252
/// This function keeps traversing the backward slice of lhs/rhs until could
235253
/// prove the relation or until it ran out of IR.
236-
static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
237-
ComparisonOperator cmp, OpFoldResult rhs,
238-
std::optional<int64_t> rhsDim);
239-
static bool compare(AffineMap lhs, ValueDimList lhsOperands,
240-
ComparisonOperator cmp, AffineMap rhs,
241-
ValueDimList rhsOperands);
242-
static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
243-
ComparisonOperator cmp, AffineMap rhs,
244-
ArrayRef<Value> rhsOperands);
245-
246-
/// Compute whether the given values/dimensions are equal. Return "failure" if
254+
static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs);
255+
256+
/// Compute whether the given variables are equal. Return "failure" if
247257
/// equality could not be determined.
248-
///
249-
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
250-
/// index-typed.
251-
static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
252-
std::optional<int64_t> dim1 = std::nullopt,
253-
std::optional<int64_t> dim2 = std::nullopt);
258+
static FailureOr<bool> areEqual(Variable var1, Variable var2);
254259

255260
/// Return "true" if the given slices are guaranteed to be overlapping.
256261
/// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +322,6 @@ class ValueBoundsConstraintSet
317322
///
318323
/// This function does not analyze any IR and does not populate any additional
319324
/// constraints.
320-
bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
321-
ComparisonOperator cmp, OpFoldResult rhs,
322-
std::optional<int64_t> rhsDim);
323325
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
324326

325327
/// Given an affine map with a single result (and map operands), add a new
@@ -374,13 +376,16 @@ class ValueBoundsConstraintSet
374376
/// constraint system. Return the position of the new column. Any operands
375377
/// that were not analyzed yet are put on the worklist.
376378
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
379+
int64_t insert(const Variable &var, bool isSymbol = true);
377380

378381
/// Project out the given column in the constraint set.
379382
void projectOut(int64_t pos);
380383

381384
/// Project out all columns for which the condition holds.
382385
void projectOut(function_ref<bool(ValueDim)> condition);
383386

387+
void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
388+
384389
/// Mapping of columns to values/shape dimensions.
385390
SmallVector<std::optional<ValueDim>> positionToValueDim;
386391
/// Reverse mapping of values/shape dimensions to columns.

mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
120120
mapOperands.push_back(value1);
121121
mapOperands.push_back(value2);
122122
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
123-
ValueDimList valueDims;
124-
for (Value v : mapOperands)
125-
valueDims.push_back({v, std::nullopt});
126123
return ValueBoundsConstraintSet::computeConstantBound(
127-
presburger::BoundType::EQ, map, valueDims);
124+
presburger::BoundType::EQ,
125+
ValueBoundsConstraintSet::Variable(map, mapOperands));
128126
}

mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
2525
AffineMap boundMap;
2626
ValueDimList mapOperands;
2727
if (failed(ValueBoundsConstraintSet::computeBound(
28-
boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
28+
boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB)))
2929
return failure();
3030

3131
// Reify bound.

mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ struct SelectOpInterface
107107
// If trueValue <= falseValue:
108108
// * result <= falseValue
109109
// * result >= trueValue
110-
if (cstr.compare(trueValue, dim,
110+
if (cstr.compare(/*lhs=*/{trueValue, dim},
111111
ValueBoundsConstraintSet::ComparisonOperator::LE,
112-
falseValue, dim)) {
112+
/*rhs=*/{falseValue, dim})) {
113113
if (dim) {
114114
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
115115
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
@@ -121,9 +121,9 @@ struct SelectOpInterface
121121
// If falseValue <= trueValue:
122122
// * result <= trueValue
123123
// * result >= falseValue
124-
if (cstr.compare(falseValue, dim,
124+
if (cstr.compare(/*lhs=*/{falseValue, dim},
125125
ValueBoundsConstraintSet::ComparisonOperator::LE,
126-
trueValue, dim)) {
126+
/*rhs=*/{trueValue, dim})) {
127127
if (dim) {
128128
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
129129
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
449449
return failure();
450450

451451
FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
452-
presburger::BoundType::UB, in, /*dim=*/std::nullopt,
452+
presburger::BoundType::UB, in,
453453
/*stopCondition=*/nullptr, /*closedUB=*/true);
454454
if (failed(ub))
455455
return failure();

mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
7070
AffineMap boundMap;
7171
ValueDimList mapOperands;
7272
if (failed(ValueBoundsConstraintSet::computeBound(
73-
boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
73+
boundMap, mapOperands, type,
74+
ValueBoundsConstraintSet::Variable(value, dim), stopCondition,
75+
closedUB)))
7476
return failure();
7577

7678
// Materialize tensor.dim/memref.dim ops.

mlir/lib/Dialect/Linalg/Transforms/Padding.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
7272
// Otherwise, try to compute a constant upper bound for the size value.
7373
FailureOr<int64_t> upperBound =
7474
ValueBoundsConstraintSet::computeConstantBound(
75-
presburger::BoundType::UB, opOperand->get(),
76-
/*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
75+
presburger::BoundType::UB,
76+
{opOperand->get(),
77+
/*dim=*/i},
78+
/*stopCondition=*/nullptr, /*closedUB=*/true);
7779
if (failed(upperBound)) {
7880
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
7981
return failure();

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
257257
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
258258
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
259259
} else {
260-
Value materializedSize =
261-
getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
262260
FailureOr<int64_t> upperBound =
263261
ValueBoundsConstraintSet::computeConstantBound(
264-
presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
262+
presburger::BoundType::UB, rangeValue.size,
265263
/*stopCondition=*/nullptr, /*closedUB=*/true);
266264
size = failed(upperBound)
267-
? materializedSize
265+
? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
268266
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
269267
}
270268
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");

mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
2323
ValueRange independencies) {
2424
if (ofr.is<Attribute>())
2525
return ofr;
26-
Value value = ofr.get<Value>();
2726
AffineMap boundMap;
2827
ValueDimList mapOperands;
2928
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
30-
boundMap, mapOperands, presburger::BoundType::UB, value,
31-
/*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
29+
boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
30+
/*closedUB=*/true)))
3231
return failure();
3332
return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
3433
}

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ struct ForOpInterface
6161
// An EQ constraint can be added if the yielded value (dimension size)
6262
// equals the corresponding block argument (dimension size).
6363
if (cstr.populateAndCompare(
64-
yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
65-
iterArg, dim)) {
64+
/*lhs=*/{yieldedValue, dim},
65+
ValueBoundsConstraintSet::ComparisonOperator::EQ,
66+
/*rhs=*/{iterArg, dim})) {
6667
if (dim.has_value()) {
6768
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
6869
} else {
69-
cstr.bound(value) == initArg;
70+
cstr.bound(value) == cstr.getExpr(initArg);
7071
}
7172
}
7273
}
@@ -113,8 +114,9 @@ struct IfOpInterface
113114
// * result <= elseValue
114115
// * result >= thenValue
115116
if (cstr.populateAndCompare(
116-
thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
117-
elseValue, dim)) {
117+
/*lhs=*/{thenValue, dim},
118+
ValueBoundsConstraintSet::ComparisonOperator::LE,
119+
/*rhs=*/{elseValue, dim})) {
118120
if (dim) {
119121
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
120122
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -127,8 +129,9 @@ struct IfOpInterface
127129
// * result <= thenValue
128130
// * result >= elseValue
129131
if (cstr.populateAndCompare(
130-
elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
131-
thenValue, dim)) {
132+
/*lhs=*/{elseValue, dim},
133+
ValueBoundsConstraintSet::ComparisonOperator::LE,
134+
/*rhs=*/{thenValue, dim})) {
132135
if (dim) {
133136
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
134137
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
289289

290290
info.isAlignedToInnerTileSize = false;
291291
FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
292-
presburger::BoundType::UB,
293-
getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
292+
presburger::BoundType::UB, tileSize,
294293
/*stopCondition=*/nullptr, /*closedUB=*/true);
295294
std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
296295
if (!failed(cstSize) && cstInnerSize) {

mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
2828
ValueDimList mapOperands;
2929
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
3030
boundMap, mapOperands, presburger::BoundType::UB, value,
31-
/*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
31+
independencies,
32+
/*closedUB=*/true)))
3233
return failure();
3334
return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
3435
}

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
154154
continue;
155155
}
156156
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
157-
op.getSource(), op.getResult(), srcDim, resultDim);
157+
{op.getSource(), srcDim}, {op.getResult(), resultDim});
158158
if (failed(equalDimSize) || !*equalDimSize)
159159
return false;
160160
++srcDim;
@@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
178178
continue;
179179
}
180180
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
181-
op.getSource(), op.getResult(), dim, resultDim);
181+
{op.getSource(), dim}, {op.getResult(), resultDim});
182182
if (failed(equalDimSize) || !*equalDimSize)
183183
return false;
184184
++resultDim;

0 commit comments

Comments
 (0)