Skip to content

Commit 3876dc9

Browse files
[mlir][Interfaces][WIP] ValueBoundsOpInterface: Variable
1 parent d34a2c2 commit 3876dc9

File tree

20 files changed

+361
-311
lines changed

20 files changed

+361
-311
lines changed

mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
5353
/// maximally compose chains of AffineApplyOps.
5454
FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
5555

56+
/// Reify a bound for the given variable in terms of SSA values for which
57+
/// `stopCondition` is met.
58+
///
59+
/// By default, lower/equal bounds are closed and upper bounds are open. If
60+
/// `closedUB` is set to "true", upper bounds are also closed.
61+
FailureOr<OpFoldResult>
62+
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
63+
const ValueBoundsConstraintSet::Variable &var,
64+
ValueBoundsConstraintSet::StopConditionFn stopCondition,
65+
bool closedUB = false);
66+
5667
/// Reify a bound for the given index-typed value in terms of SSA values for
5768
/// which `stopCondition` is met. If no stop condition is specified, reify in
5869
/// terms of the operands of the owner op.

mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ enum class BoundType;
2424

2525
namespace arith {
2626

27+
/// Reify a bound for the given variable in terms of SSA values for which
28+
/// `stopCondition` is met.
29+
///
30+
/// By default, lower/equal bounds are closed and upper bounds are open. If
31+
/// `closedUB` is set to "true", upper bounds are also closed.
32+
FailureOr<OpFoldResult>
33+
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
34+
const ValueBoundsConstraintSet::Variable &var,
35+
ValueBoundsConstraintSet::StopConditionFn stopCondition,
36+
bool closedUB = false);
37+
2738
/// Reify a bound for the given index-typed value in terms of SSA values for
2839
/// which `stopCondition` is met. If no stop condition is specified, reify in
2940
/// terms of the operands of the owner op.

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 63 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/Value.h"
1616
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1717
#include "llvm/ADT/SetVector.h"
18+
#include "llvm/ADT/SmallVector.h"
1819
#include "llvm/Support/ExtensibleRTTI.h"
1920

2021
#include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
111112
public:
112113
static char ID;
113114

115+
/// A variable that can be added to the constraint set as a "column". The
116+
/// value bounds infrastructure can compute bounds for variables and compare
117+
/// two variables.
118+
///
119+
/// Internally, a variable is represented as an affine map and operands.
120+
class Variable {
121+
public:
122+
/// Construct a variable for an index-typed attribute or SSA value.
123+
Variable(OpFoldResult ofr);
124+
125+
/// Construct a variable for an index-typed SSA value.
126+
Variable(Value indexValue);
127+
128+
/// Construct a variable for a dimension of a shaped value.
129+
Variable(Value shapedValue, int64_t dim);
130+
131+
/// Construct a variable for an index-typed attribute/SSA value or for a
132+
/// dimension of a shaped value. A non-null dimension must be provided if
133+
/// and only if `ofr` is a shaped value.
134+
Variable(OpFoldResult ofr, std::optional<int64_t> dim);
135+
136+
/// Construct a variable for a map and its operands.
137+
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
138+
Variable(AffineMap map, ArrayRef<Value> mapOperands);
139+
140+
MLIRContext *getContext() const { return map.getContext(); }
141+
142+
private:
143+
friend class ValueBoundsConstraintSet;
144+
AffineMap map;
145+
ValueDimList mapOperands;
146+
};
147+
114148
/// The stop condition when traversing the backward slice of a shaped value/
115149
/// index-type value. The traversal continues until the stop condition
116150
/// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
121155
using StopConditionFn = std::function<bool(
122156
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
123157

124-
/// Compute a bound for the given index-typed value or shape dimension size.
125-
/// The computed bound is stored in `resultMap`. The operands of the bound are
126-
/// stored in `mapOperands`. An operand is either an index-type SSA value
127-
/// or a shaped value and a dimension.
158+
/// Compute a bound for the given variable. The computed bound is stored in
159+
/// `resultMap`. The operands of the bound are stored in `mapOperands`. An
160+
/// operand is either an index-type SSA value or a shaped value and a
161+
/// dimension.
128162
///
129-
/// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
130-
/// is computed in terms of values/dimensions for which `stopCondition`
131-
/// evaluates to "true". To that end, the backward slice (reverse use-def
132-
/// chain) of the given value is visited in a worklist-driven manner and the
133-
/// constraint set is populated according to `ValueBoundsOpInterface` for each
134-
/// visited value.
163+
/// The bound is computed in terms of values/dimensions for which
164+
/// `stopCondition` evaluates to "true". To that end, the backward slice
165+
/// (reverse use-def chain) of the given value is visited in a worklist-driven
166+
/// manner and the constraint set is populated according to
167+
/// `ValueBoundsOpInterface` for each visited value.
135168
///
136169
/// By default, lower/equal bounds are closed and upper bounds are open. If
137170
/// `closedUB` is set to "true", upper bounds are also closed.
138-
static LogicalResult computeBound(AffineMap &resultMap,
139-
ValueDimList &mapOperands,
140-
presburger::BoundType type, Value value,
141-
std::optional<int64_t> dim,
142-
StopConditionFn stopCondition,
143-
bool closedUB = false);
171+
static LogicalResult
172+
computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
173+
presburger::BoundType type, const Variable &var,
174+
StopConditionFn stopCondition, bool closedUB = false);
144175

145176
/// Compute a bound in terms of the values/dimensions in `dependencies`. The
146177
/// computed bound consists of only constant terms and dependent values (or
147178
/// dimension sizes thereof).
148179
static LogicalResult
149180
computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
150-
presburger::BoundType type, Value value,
151-
std::optional<int64_t> dim, ValueDimList dependencies,
152-
bool closedUB = false);
181+
presburger::BoundType type, const Variable &var,
182+
ValueDimList dependencies, bool closedUB = false);
153183

154184
/// Compute a bound in that is independent of all values in `independencies`.
155185
///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
161191
/// appear in the computed bound.
162192
static LogicalResult
163193
computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
164-
presburger::BoundType type, Value value,
165-
std::optional<int64_t> dim, ValueRange independencies,
166-
bool closedUB = false);
194+
presburger::BoundType type, const Variable &var,
195+
ValueRange independencies, bool closedUB = false);
167196

168-
/// Compute a constant bound for the given affine map, where dims and symbols
169-
/// are bound to the given operands. The affine map must have exactly one
170-
/// result.
197+
/// Compute a constant bound for the given variable.
171198
///
172199
/// This function traverses the backward slice of the given operands in a
173200
/// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
182209
/// By default, lower/equal bounds are closed and upper bounds are open. If
183210
/// `closedUB` is set to "true", upper bounds are also closed.
184211
static FailureOr<int64_t>
185-
computeConstantBound(presburger::BoundType type, Value value,
186-
std::optional<int64_t> dim = std::nullopt,
212+
computeConstantBound(presburger::BoundType type, const Variable &var,
187213
StopConditionFn stopCondition = nullptr,
188214
bool closedUB = false);
189-
static FailureOr<int64_t> computeConstantBound(
190-
presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
191-
StopConditionFn stopCondition = nullptr, bool closedUB = false);
192-
static FailureOr<int64_t> computeConstantBound(
193-
presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
194-
StopConditionFn stopCondition = nullptr, bool closedUB = false);
195215

196216
/// Compute a constant delta between the given two values. Return "failure"
197217
/// if a constant delta could not be determined.
@@ -221,9 +241,8 @@ class ValueBoundsConstraintSet
221241
/// proven. This could be because the specified relation does in fact not hold
222242
/// or because there is not enough information in the constraint set. In other
223243
/// words, if we do not know for sure, this function returns "false".
224-
bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
225-
ComparisonOperator cmp, OpFoldResult rhs,
226-
std::optional<int64_t> rhsDim);
244+
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp,
245+
const Variable &rhs);
227246

228247
/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
229248
/// specified relation could not be proven. This could be because the
@@ -233,24 +252,12 @@ class ValueBoundsConstraintSet
233252
///
234253
/// This function keeps traversing the backward slice of lhs/rhs until could
235254
/// prove the relation or until it ran out of IR.
236-
static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
237-
ComparisonOperator cmp, OpFoldResult rhs,
238-
std::optional<int64_t> rhsDim);
239-
static bool compare(AffineMap lhs, ValueDimList lhsOperands,
240-
ComparisonOperator cmp, AffineMap rhs,
241-
ValueDimList rhsOperands);
242-
static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
243-
ComparisonOperator cmp, AffineMap rhs,
244-
ArrayRef<Value> rhsOperands);
245-
246-
/// Compute whether the given values/dimensions are equal. Return "failure" if
255+
static bool compare(const Variable &lhs, ComparisonOperator cmp,
256+
const Variable &rhs);
257+
258+
/// Compute whether the given variables are equal. Return "failure" if
247259
/// equality could not be determined.
248-
///
249-
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
250-
/// index-typed.
251-
static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
252-
std::optional<int64_t> dim1 = std::nullopt,
253-
std::optional<int64_t> dim2 = std::nullopt);
260+
static FailureOr<bool> areEqual(const Variable &var1, const Variable &var2);
254261

255262
/// Return "true" if the given slices are guaranteed to be overlapping.
256263
/// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +324,6 @@ class ValueBoundsConstraintSet
317324
///
318325
/// This function does not analyze any IR and does not populate any additional
319326
/// constraints.
320-
bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
321-
ComparisonOperator cmp, OpFoldResult rhs,
322-
std::optional<int64_t> rhsDim);
323327
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
324328

325329
/// Given an affine map with a single result (and map operands), add a new
@@ -374,13 +378,16 @@ class ValueBoundsConstraintSet
374378
/// constraint system. Return the position of the new column. Any operands
375379
/// that were not analyzed yet are put on the worklist.
376380
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
381+
int64_t insert(const Variable &var, bool isSymbol = true);
377382

378383
/// Project out the given column in the constraint set.
379384
void projectOut(int64_t pos);
380385

381386
/// Project out all columns for which the condition holds.
382387
void projectOut(function_ref<bool(ValueDim)> condition);
383388

389+
void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
390+
384391
/// Mapping of columns to values/shape dimensions.
385392
SmallVector<std::optional<ValueDim>> positionToValueDim;
386393
/// Reverse mapping of values/shape dimensions to columns.

mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
120120
mapOperands.push_back(value1);
121121
mapOperands.push_back(value2);
122122
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
123-
ValueDimList valueDims;
124-
for (Value v : mapOperands)
125-
valueDims.push_back({v, std::nullopt});
126123
return ValueBoundsConstraintSet::computeConstantBound(
127-
presburger::BoundType::EQ, map, valueDims);
124+
presburger::BoundType::EQ,
125+
ValueBoundsConstraintSet::Variable(map, mapOperands));
128126
}

mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,15 @@
1616
using namespace mlir;
1717
using namespace mlir::affine;
1818

19-
static FailureOr<OpFoldResult>
20-
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
21-
Value value, std::optional<int64_t> dim,
22-
ValueBoundsConstraintSet::StopConditionFn stopCondition,
23-
bool closedUB) {
19+
FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
20+
OpBuilder &b, Location loc, presburger::BoundType type,
21+
const ValueBoundsConstraintSet::Variable &var,
22+
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
2423
// Compute bound.
2524
AffineMap boundMap;
2625
ValueDimList mapOperands;
2726
if (failed(ValueBoundsConstraintSet::computeBound(
28-
boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
27+
boundMap, mapOperands, type, var, stopCondition, closedUB)))
2928
return failure();
3029

3130
// Reify bound.
@@ -93,7 +92,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
9392
// the owner of `value`.
9493
return v != value;
9594
};
96-
return reifyValueBound(b, loc, type, value, dim,
95+
return reifyValueBound(b, loc, type, {value, dim},
9796
stopCondition ? stopCondition : reifyToOperands,
9897
closedUB);
9998
}
@@ -105,7 +104,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
105104
ValueBoundsConstraintSet &cstr) {
106105
return v != value;
107106
};
108-
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
107+
return reifyValueBound(b, loc, type, value,
109108
stopCondition ? stopCondition : reifyToOperands,
110109
closedUB);
111110
}

mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ struct SelectOpInterface
107107
// If trueValue <= falseValue:
108108
// * result <= falseValue
109109
// * result >= trueValue
110-
if (cstr.compare(trueValue, dim,
110+
if (cstr.compare(/*lhs=*/{trueValue, dim},
111111
ValueBoundsConstraintSet::ComparisonOperator::LE,
112-
falseValue, dim)) {
112+
/*rhs=*/{falseValue, dim})) {
113113
if (dim) {
114114
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
115115
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
@@ -121,9 +121,9 @@ struct SelectOpInterface
121121
// If falseValue <= trueValue:
122122
// * result <= trueValue
123123
// * result >= falseValue
124-
if (cstr.compare(falseValue, dim,
124+
if (cstr.compare(/*lhs=*/{falseValue, dim},
125125
ValueBoundsConstraintSet::ComparisonOperator::LE,
126-
trueValue, dim)) {
126+
/*rhs=*/{trueValue, dim})) {
127127
if (dim) {
128128
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
129129
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
449449
return failure();
450450

451451
FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
452-
presburger::BoundType::UB, in, /*dim=*/std::nullopt,
452+
presburger::BoundType::UB, in,
453453
/*stopCondition=*/nullptr, /*closedUB=*/true);
454454
if (failed(ub))
455455
return failure();

mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
6161
return buildExpr(map.getResult(0));
6262
}
6363

64-
static FailureOr<OpFoldResult>
65-
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
66-
Value value, std::optional<int64_t> dim,
67-
ValueBoundsConstraintSet::StopConditionFn stopCondition,
68-
bool closedUB) {
64+
FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
65+
OpBuilder &b, Location loc, presburger::BoundType type,
66+
const ValueBoundsConstraintSet::Variable &var,
67+
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
6968
// Compute bound.
7069
AffineMap boundMap;
7170
ValueDimList mapOperands;
7271
if (failed(ValueBoundsConstraintSet::computeBound(
73-
boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
72+
boundMap, mapOperands, type, var, stopCondition, closedUB)))
7473
return failure();
7574

7675
// Materialize tensor.dim/memref.dim ops.
@@ -128,7 +127,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
128127
// the owner of `value`.
129128
return v != value;
130129
};
131-
return reifyValueBound(b, loc, type, value, dim,
130+
return reifyValueBound(b, loc, type, {value, dim},
132131
stopCondition ? stopCondition : reifyToOperands,
133132
closedUB);
134133
}
@@ -140,7 +139,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
140139
ValueBoundsConstraintSet &cstr) {
141140
return v != value;
142141
};
143-
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
142+
return reifyValueBound(b, loc, type, value,
144143
stopCondition ? stopCondition : reifyToOperands,
145144
closedUB);
146145
}

mlir/lib/Dialect/Linalg/Transforms/Padding.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
7272
// Otherwise, try to compute a constant upper bound for the size value.
7373
FailureOr<int64_t> upperBound =
7474
ValueBoundsConstraintSet::computeConstantBound(
75-
presburger::BoundType::UB, opOperand->get(),
76-
/*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
75+
presburger::BoundType::UB,
76+
{opOperand->get(),
77+
/*dim=*/i},
78+
/*stopCondition=*/nullptr, /*closedUB=*/true);
7779
if (failed(upperBound)) {
7880
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
7981
return failure();

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
257257
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
258258
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
259259
} else {
260-
Value materializedSize =
261-
getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
262260
FailureOr<int64_t> upperBound =
263261
ValueBoundsConstraintSet::computeConstantBound(
264-
presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
262+
presburger::BoundType::UB, rangeValue.size,
265263
/*stopCondition=*/nullptr, /*closedUB=*/true);
266264
size = failed(upperBound)
267-
? materializedSize
265+
? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
268266
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
269267
}
270268
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");

0 commit comments

Comments
 (0)