Skip to content

Commit ed12ff5

Browse files
[mlir][Interfaces][WIP] ValueBoundsOpInterface: Variable
1 parent 4a019ca commit ed12ff5

File tree

15 files changed

+312
-275
lines changed

15 files changed

+312
-275
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/Value.h"
1616
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1717
#include "llvm/ADT/SetVector.h"
18+
#include "llvm/ADT/SmallVector.h"
1819
#include "llvm/Support/ExtensibleRTTI.h"
1920

2021
#include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
111112
public:
112113
static char ID;
113114

115+
/// A variable that can be added to the constraint set as a "column". The
116+
/// value bounds infrastructure can compute bounds for variables and compare
117+
/// two variables.
118+
///
119+
/// Internally, a variable is represented as an affine map and operands.
120+
class Variable {
121+
public:
122+
/// Construct a variable for an index-typed attribute or SSA value.
123+
Variable(OpFoldResult ofr);
124+
125+
/// Construct a variable for an index-typed SSA value.
126+
Variable(Value indexValue);
127+
128+
/// Construct a variable for a dimension of a shaped value.
129+
Variable(Value shapedValue, int64_t dim);
130+
131+
/// Construct a variable for an index-typed attribute/SSA value or for a
132+
/// dimension of a shaped value. A non-null dimension must be provided if
133+
/// and only if `ofr` is a shaped value.
134+
Variable(OpFoldResult ofr, std::optional<int64_t> dim);
135+
136+
/// Construct a variable for a map and its operands.
137+
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
138+
Variable(AffineMap map, ArrayRef<Value> mapOperands);
139+
140+
MLIRContext *getContext() const { return map.getContext(); }
141+
142+
private:
143+
friend class ValueBoundsConstraintSet;
144+
AffineMap map;
145+
ValueDimList mapOperands;
146+
};
147+
114148
/// The stop condition when traversing the backward slice of a shaped value/
115149
/// index-type value. The traversal continues until the stop condition
116150
/// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
121155
using StopConditionFn = std::function<bool(
122156
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
123157

124-
/// Compute a bound for the given index-typed value or shape dimension size.
125-
/// The computed bound is stored in `resultMap`. The operands of the bound are
126-
/// stored in `mapOperands`. An operand is either an index-type SSA value
127-
/// or a shaped value and a dimension.
158+
/// Compute a bound for the given variable. The computed bound is stored in
159+
/// `resultMap`. The operands of the bound are stored in `mapOperands`. An
160+
/// operand is either an index-type SSA value or a shaped value and a
161+
/// dimension.
128162
///
129-
/// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
130-
/// is computed in terms of values/dimensions for which `stopCondition`
131-
/// evaluates to "true". To that end, the backward slice (reverse use-def
132-
/// chain) of the given value is visited in a worklist-driven manner and the
133-
/// constraint set is populated according to `ValueBoundsOpInterface` for each
134-
/// visited value.
163+
/// The bound is computed in terms of values/dimensions for which
164+
/// `stopCondition` evaluates to "true". To that end, the backward slice
165+
/// (reverse use-def chain) of the given value is visited in a worklist-driven
166+
/// manner and the constraint set is populated according to
167+
/// `ValueBoundsOpInterface` for each visited value.
135168
///
136169
/// By default, lower/equal bounds are closed and upper bounds are open. If
137170
/// `closedUB` is set to "true", upper bounds are also closed.
138-
static LogicalResult computeBound(AffineMap &resultMap,
139-
ValueDimList &mapOperands,
140-
presburger::BoundType type, Value value,
141-
std::optional<int64_t> dim,
142-
StopConditionFn stopCondition,
143-
bool closedUB = false);
171+
static LogicalResult
172+
computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
173+
presburger::BoundType type, const Variable &var,
174+
StopConditionFn stopCondition, bool closedUB = false);
144175

145176
/// Compute a bound in terms of the values/dimensions in `dependencies`. The
146177
/// computed bound consists of only constant terms and dependent values (or
147178
/// dimension sizes thereof).
148179
static LogicalResult
149180
computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
150-
presburger::BoundType type, Value value,
151-
std::optional<int64_t> dim, ValueDimList dependencies,
152-
bool closedUB = false);
181+
presburger::BoundType type, const Variable &var,
182+
ValueDimList dependencies, bool closedUB = false);
153183

154184
/// Compute a bound in that is independent of all values in `independencies`.
155185
///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
161191
/// appear in the computed bound.
162192
static LogicalResult
163193
computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
164-
presburger::BoundType type, Value value,
165-
std::optional<int64_t> dim, ValueRange independencies,
166-
bool closedUB = false);
194+
presburger::BoundType type, const Variable &var,
195+
ValueRange independencies, bool closedUB = false);
167196

168-
/// Compute a constant bound for the given affine map, where dims and symbols
169-
/// are bound to the given operands. The affine map must have exactly one
170-
/// result.
197+
/// Compute a constant bound for the given variable.
171198
///
172199
/// This function traverses the backward slice of the given operands in a
173200
/// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
182209
/// By default, lower/equal bounds are closed and upper bounds are open. If
183210
/// `closedUB` is set to "true", upper bounds are also closed.
184211
static FailureOr<int64_t>
185-
computeConstantBound(presburger::BoundType type, Value value,
186-
std::optional<int64_t> dim = std::nullopt,
212+
computeConstantBound(presburger::BoundType type, const Variable &var,
187213
StopConditionFn stopCondition = nullptr,
188214
bool closedUB = false);
189-
static FailureOr<int64_t> computeConstantBound(
190-
presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
191-
StopConditionFn stopCondition = nullptr, bool closedUB = false);
192-
static FailureOr<int64_t> computeConstantBound(
193-
presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
194-
StopConditionFn stopCondition = nullptr, bool closedUB = false);
195215

196216
/// Compute a constant delta between the given two values. Return "failure"
197217
/// if a constant delta could not be determined.
@@ -221,9 +241,7 @@ class ValueBoundsConstraintSet
221241
/// proven. This could be because the specified relation does in fact not hold
222242
/// or because there is not enough information in the constraint set. In other
223243
/// words, if we do not know for sure, this function returns "false".
224-
bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
225-
ComparisonOperator cmp, OpFoldResult rhs,
226-
std::optional<int64_t> rhsDim);
244+
bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs);
227245

228246
/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
229247
/// specified relation could not be proven. This could be because the
@@ -233,24 +251,11 @@ class ValueBoundsConstraintSet
233251
///
234252
/// This function keeps traversing the backward slice of lhs/rhs until could
235253
/// prove the relation or until it ran out of IR.
236-
static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
237-
ComparisonOperator cmp, OpFoldResult rhs,
238-
std::optional<int64_t> rhsDim);
239-
static bool compare(AffineMap lhs, ValueDimList lhsOperands,
240-
ComparisonOperator cmp, AffineMap rhs,
241-
ValueDimList rhsOperands);
242-
static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
243-
ComparisonOperator cmp, AffineMap rhs,
244-
ArrayRef<Value> rhsOperands);
245-
246-
/// Compute whether the given values/dimensions are equal. Return "failure" if
254+
static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs);
255+
256+
/// Compute whether the given variables are equal. Return "failure" if
247257
/// equality could not be determined.
248-
///
249-
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
250-
/// index-typed.
251-
static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
252-
std::optional<int64_t> dim1 = std::nullopt,
253-
std::optional<int64_t> dim2 = std::nullopt);
258+
static FailureOr<bool> areEqual(Variable var1, Variable var2);
254259

255260
/// Return "true" if the given slices are guaranteed to be overlapping.
256261
/// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +322,6 @@ class ValueBoundsConstraintSet
317322
///
318323
/// This function does not analyze any IR and does not populate any additional
319324
/// constraints.
320-
bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
321-
ComparisonOperator cmp, OpFoldResult rhs,
322-
std::optional<int64_t> rhsDim);
323325
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
324326

325327
/// Given an affine map with a single result (and map operands), add a new
@@ -374,13 +376,16 @@ class ValueBoundsConstraintSet
374376
/// constraint system. Return the position of the new column. Any operands
375377
/// that were not analyzed yet are put on the worklist.
376378
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
379+
int64_t insert(const Variable &var, bool isSymbol = true);
377380

378381
/// Project out the given column in the constraint set.
379382
void projectOut(int64_t pos);
380383

381384
/// Project out all columns for which the condition holds.
382385
void projectOut(function_ref<bool(ValueDim)> condition);
383386

387+
void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
388+
384389
/// Mapping of columns to values/shape dimensions.
385390
SmallVector<std::optional<ValueDim>> positionToValueDim;
386391
/// Reverse mapping of values/shape dimensions to columns.

mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
120120
mapOperands.push_back(value1);
121121
mapOperands.push_back(value2);
122122
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
123-
ValueDimList valueDims;
124-
for (Value v : mapOperands)
125-
valueDims.push_back({v, std::nullopt});
126123
return ValueBoundsConstraintSet::computeConstantBound(
127-
presburger::BoundType::EQ, map, valueDims);
124+
presburger::BoundType::EQ,
125+
ValueBoundsConstraintSet::Variable(map, mapOperands));
128126
}

mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
2525
AffineMap boundMap;
2626
ValueDimList mapOperands;
2727
if (failed(ValueBoundsConstraintSet::computeBound(
28-
boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
28+
boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB)))
2929
return failure();
3030

3131
// Reify bound.

mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,75 @@ struct MulIOpInterface
7575
}
7676
};
7777

78+
struct SelectOpInterface
79+
: public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
80+
SelectOp> {
81+
82+
static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
83+
ValueBoundsConstraintSet &cstr) {
84+
Value value = selectOp.getResult();
85+
Value condition = selectOp.getCondition();
86+
Value trueValue = selectOp.getTrueValue();
87+
Value falseValue = selectOp.getFalseValue();
88+
89+
if (isa<ShapedType>(condition.getType())) {
90+
// If the condition is a shaped type, the condition is applied
91+
// element-wise. All three operands must have the same shape.
92+
cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
93+
cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
94+
cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
95+
return;
96+
}
97+
98+
// Populate constraints for the true/false values (and all values on the
99+
// backward slice, as long as the current stop condition is not satisfied).
100+
cstr.populateConstraints(trueValue, dim);
101+
cstr.populateConstraints(falseValue, dim);
102+
auto boundsBuilder = cstr.bound(value);
103+
if (dim)
104+
boundsBuilder[*dim];
105+
106+
// Compare yielded values.
107+
// If trueValue <= falseValue:
108+
// * result <= falseValue
109+
// * result >= trueValue
110+
if (cstr.compare(/*lhs=*/{trueValue, dim},
111+
ValueBoundsConstraintSet::ComparisonOperator::LE,
112+
/*rhs=*/{falseValue, dim})) {
113+
if (dim) {
114+
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
115+
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
116+
} else {
117+
cstr.bound(value) >= trueValue;
118+
cstr.bound(value) <= falseValue;
119+
}
120+
}
121+
// If falseValue <= trueValue:
122+
// * result <= trueValue
123+
// * result >= falseValue
124+
if (cstr.compare(/*lhs=*/{falseValue, dim},
125+
ValueBoundsConstraintSet::ComparisonOperator::LE,
126+
/*rhs=*/{trueValue, dim})) {
127+
if (dim) {
128+
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
129+
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
130+
} else {
131+
cstr.bound(value) >= falseValue;
132+
cstr.bound(value) <= trueValue;
133+
}
134+
}
135+
}
136+
137+
void populateBoundsForIndexValue(Operation *op, Value value,
138+
ValueBoundsConstraintSet &cstr) const {
139+
populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
140+
}
141+
142+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
143+
ValueBoundsConstraintSet &cstr) const {
144+
populateBounds(cast<SelectOp>(op), dim, cstr);
145+
}
146+
};
78147
} // namespace
79148
} // namespace arith
80149
} // namespace mlir

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
449449
return failure();
450450

451451
FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
452-
presburger::BoundType::UB, in, /*dim=*/std::nullopt,
452+
presburger::BoundType::UB, in,
453453
/*stopCondition=*/nullptr, /*closedUB=*/true);
454454
if (failed(ub))
455455
return failure();

mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
7070
AffineMap boundMap;
7171
ValueDimList mapOperands;
7272
if (failed(ValueBoundsConstraintSet::computeBound(
73-
boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
73+
boundMap, mapOperands, type,
74+
ValueBoundsConstraintSet::Variable(value, dim), stopCondition,
75+
closedUB)))
7476
return failure();
7577

7678
// Materialize tensor.dim/memref.dim ops.

mlir/lib/Dialect/Linalg/Transforms/Padding.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
7272
// Otherwise, try to compute a constant upper bound for the size value.
7373
FailureOr<int64_t> upperBound =
7474
ValueBoundsConstraintSet::computeConstantBound(
75-
presburger::BoundType::UB, opOperand->get(),
76-
/*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
75+
presburger::BoundType::UB,
76+
{opOperand->get(),
77+
/*dim=*/i},
78+
/*stopCondition=*/nullptr, /*closedUB=*/true);
7779
if (failed(upperBound)) {
7880
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
7981
return failure();

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
257257
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
258258
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
259259
} else {
260-
Value materializedSize =
261-
getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
262260
FailureOr<int64_t> upperBound =
263261
ValueBoundsConstraintSet::computeConstantBound(
264-
presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
262+
presburger::BoundType::UB, rangeValue.size,
265263
/*stopCondition=*/nullptr, /*closedUB=*/true);
266264
size = failed(upperBound)
267-
? materializedSize
265+
? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
268266
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
269267
}
270268
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");

mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
2323
ValueRange independencies) {
2424
if (ofr.is<Attribute>())
2525
return ofr;
26-
Value value = ofr.get<Value>();
2726
AffineMap boundMap;
2827
ValueDimList mapOperands;
2928
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
30-
boundMap, mapOperands, presburger::BoundType::UB, value,
31-
/*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
29+
boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
30+
/*closedUB=*/true)))
3231
return failure();
3332
return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
3433
}

0 commit comments

Comments
 (0)