Skip to content

[mlir][Interfaces] Variable abstraction for ValueBoundsOpInterface #87980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
/// maximally compose chains of AffineApplyOps.
FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);

/// Reify a bound for the given variable in terms of SSA values for which
/// `stopCondition` is met.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
FailureOr<OpFoldResult>
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
const ValueBoundsConstraintSet::Variable &var,
ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB = false);

/// Reify a bound for the given index-typed value in terms of SSA values for
/// which `stopCondition` is met. If no stop condition is specified, reify in
/// terms of the operands of the owner op.
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ enum class BoundType;

namespace arith {

/// Reify a bound for the given variable in terms of SSA values for which
/// `stopCondition` is met.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
FailureOr<OpFoldResult>
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
const ValueBoundsConstraintSet::Variable &var,
ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB = false);

/// Reify a bound for the given index-typed value in terms of SSA values for
/// which `stopCondition` is met. If no stop condition is specified, reify in
/// terms of the operands of the owner op.
Expand Down
119 changes: 63 additions & 56 deletions mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ExtensibleRTTI.h"

#include <queue>
Expand Down Expand Up @@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
public:
static char ID;

/// A variable that can be added to the constraint set as a "column". The
/// value bounds infrastructure can compute bounds for variables and compare
/// two variables.
///
/// Internally, a variable is represented as an affine map and operands.
class Variable {
public:
/// Construct a variable for an index-typed attribute or SSA value.
Variable(OpFoldResult ofr);

/// Construct a variable for an index-typed SSA value.
Variable(Value indexValue);

/// Construct a variable for a dimension of a shaped value.
Variable(Value shapedValue, int64_t dim);

/// Construct a variable for an index-typed attribute/SSA value or for a
/// dimension of a shaped value. A non-null dimension must be provided if
/// and only if `ofr` is a shaped value.
Variable(OpFoldResult ofr, std::optional<int64_t> dim);

/// Construct a variable for a map and its operands.
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
Variable(AffineMap map, ArrayRef<Value> mapOperands);

MLIRContext *getContext() const { return map.getContext(); }

private:
friend class ValueBoundsConstraintSet;
AffineMap map;
ValueDimList mapOperands;
};

/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
Expand All @@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
using StopConditionFn = std::function<bool(
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;

/// Compute a bound for the given index-typed value or shape dimension size.
/// The computed bound is stored in `resultMap`. The operands of the bound are
/// stored in `mapOperands`. An operand is either an index-type SSA value
/// or a shaped value and a dimension.
/// Compute a bound for the given variable. The computed bound is stored in
/// `resultMap`. The operands of the bound are stored in `mapOperands`. An
/// operand is either an index-type SSA value or a shaped value and a
/// dimension.
///
/// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
/// is computed in terms of values/dimensions for which `stopCondition`
/// evaluates to "true". To that end, the backward slice (reverse use-def
/// chain) of the given value is visited in a worklist-driven manner and the
/// constraint set is populated according to `ValueBoundsOpInterface` for each
/// visited value.
/// The bound is computed in terms of values/dimensions for which
/// `stopCondition` evaluates to "true". To that end, the backward slice
/// (reverse use-def chain) of the given value is visited in a worklist-driven
/// manner and the constraint set is populated according to
/// `ValueBoundsOpInterface` for each visited value.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
static LogicalResult computeBound(AffineMap &resultMap,
ValueDimList &mapOperands,
presburger::BoundType type, Value value,
std::optional<int64_t> dim,
StopConditionFn stopCondition,
bool closedUB = false);
static LogicalResult
computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition, bool closedUB = false);

/// Compute a bound in terms of the values/dimensions in `dependencies`. The
/// computed bound consists of only constant terms and dependent values (or
/// dimension sizes thereof).
static LogicalResult
computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
presburger::BoundType type, Value value,
std::optional<int64_t> dim, ValueDimList dependencies,
bool closedUB = false);
presburger::BoundType type, const Variable &var,
ValueDimList dependencies, bool closedUB = false);

/// Compute a bound in that is independent of all values in `independencies`.
///
Expand All @@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
/// appear in the computed bound.
static LogicalResult
computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
presburger::BoundType type, Value value,
std::optional<int64_t> dim, ValueRange independencies,
bool closedUB = false);
presburger::BoundType type, const Variable &var,
ValueRange independencies, bool closedUB = false);

/// Compute a constant bound for the given affine map, where dims and symbols
/// are bound to the given operands. The affine map must have exactly one
/// result.
/// Compute a constant bound for the given variable.
///
/// This function traverses the backward slice of the given operands in a
/// worklist-driven manner until `stopCondition` evaluates to "true". The
Expand All @@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
static FailureOr<int64_t>
computeConstantBound(presburger::BoundType type, Value value,
std::optional<int64_t> dim = std::nullopt,
computeConstantBound(presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition = nullptr,
bool closedUB = false);
static FailureOr<int64_t> computeConstantBound(
presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
StopConditionFn stopCondition = nullptr, bool closedUB = false);
static FailureOr<int64_t> computeConstantBound(
presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
StopConditionFn stopCondition = nullptr, bool closedUB = false);

/// Compute a constant delta between the given two values. Return "failure"
/// if a constant delta could not be determined.
Expand Down Expand Up @@ -221,9 +241,8 @@ class ValueBoundsConstraintSet
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
ComparisonOperator cmp, OpFoldResult rhs,
std::optional<int64_t> rhsDim);
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp,
const Variable &rhs);

/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
/// specified relation could not be proven. This could be because the
Expand All @@ -233,24 +252,12 @@ class ValueBoundsConstraintSet
///
/// This function keeps traversing the backward slice of lhs/rhs until could
/// prove the relation or until it ran out of IR.
static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
ComparisonOperator cmp, OpFoldResult rhs,
std::optional<int64_t> rhsDim);
static bool compare(AffineMap lhs, ValueDimList lhsOperands,
ComparisonOperator cmp, AffineMap rhs,
ValueDimList rhsOperands);
static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
ComparisonOperator cmp, AffineMap rhs,
ArrayRef<Value> rhsOperands);

/// Compute whether the given values/dimensions are equal. Return "failure" if
static bool compare(const Variable &lhs, ComparisonOperator cmp,
const Variable &rhs);

/// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
///
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
/// index-typed.
static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);
static FailureOr<bool> areEqual(const Variable &var1, const Variable &var2);

/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
Expand Down Expand Up @@ -317,9 +324,6 @@ class ValueBoundsConstraintSet
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
ComparisonOperator cmp, OpFoldResult rhs,
std::optional<int64_t> rhsDim);
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);

/// Given an affine map with a single result (and map operands), add a new
Expand Down Expand Up @@ -374,13 +378,16 @@ class ValueBoundsConstraintSet
/// constraint system. Return the position of the new column. Any operands
/// that were not analyzed yet are put on the worklist.
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
int64_t insert(const Variable &var, bool isSymbol = true);

/// Project out the given column in the constraint set.
void projectOut(int64_t pos);

/// Project out all columns for which the condition holds.
void projectOut(function_ref<bool(ValueDim)> condition);

void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);

/// Mapping of columns to values/shape dimensions.
SmallVector<std::optional<ValueDim>> positionToValueDim;
/// Reverse mapping of values/shape dimensions to columns.
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
mapOperands.push_back(value1);
mapOperands.push_back(value2);
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
ValueDimList valueDims;
for (Value v : mapOperands)
valueDims.push_back({v, std::nullopt});
return ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::EQ, map, valueDims);
presburger::BoundType::EQ,
ValueBoundsConstraintSet::Variable(map, mapOperands));
}
15 changes: 7 additions & 8 deletions mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@
using namespace mlir;
using namespace mlir::affine;

static FailureOr<OpFoldResult>
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
Value value, std::optional<int64_t> dim,
ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
OpBuilder &b, Location loc, presburger::BoundType type,
const ValueBoundsConstraintSet::Variable &var,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();

// Reify bound.
Expand Down Expand Up @@ -93,7 +92,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
// the owner of `value`.
return v != value;
};
return reifyValueBound(b, loc, type, value, dim,
return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
Expand All @@ -105,7 +104,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ struct SelectOpInterface
// If trueValue <= falseValue:
// * result <= falseValue
// * result >= trueValue
if (cstr.compare(trueValue, dim,
if (cstr.compare(/*lhs=*/{trueValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
falseValue, dim)) {
/*rhs=*/{falseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
Expand All @@ -121,9 +121,9 @@ struct SelectOpInterface
// If falseValue <= trueValue:
// * result <= trueValue
// * result >= falseValue
if (cstr.compare(falseValue, dim,
if (cstr.compare(/*lhs=*/{falseValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
trueValue, dim)) {
/*rhs=*/{trueValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
return failure();

FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, in, /*dim=*/std::nullopt,
presburger::BoundType::UB, in,
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(ub))
return failure();
Expand Down
15 changes: 7 additions & 8 deletions mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
return buildExpr(map.getResult(0));
}

static FailureOr<OpFoldResult>
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
Value value, std::optional<int64_t> dim,
ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
OpBuilder &b, Location loc, presburger::BoundType type,
const ValueBoundsConstraintSet::Variable &var,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();

// Materialize tensor.dim/memref.dim ops.
Expand Down Expand Up @@ -128,7 +127,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
// the owner of `value`.
return v != value;
};
return reifyValueBound(b, loc, type, value, dim,
return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
Expand All @@ -140,7 +139,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, opOperand->get(),
/*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
presburger::BoundType::UB,
{opOperand->get(),
/*dim=*/i},
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound)) {
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
return failure();
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
Value materializedSize =
getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
presburger::BoundType::UB, rangeValue.size,
/*stopCondition=*/nullptr, /*closedUB=*/true);
size = failed(upperBound)
? materializedSize
? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
}
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
Expand Down
Loading