Skip to content

[mlir][Interfaces] ValueBoundsOpInterface: Add API to compare values #86915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 47 additions & 10 deletions mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ class ValueBoundsConstraintSet
/// Comparison operator for `ValueBoundsConstraintSet::compare`.
enum ComparisonOperator { LT, LE, EQ, GT, GE };

/// Try to prove that, based on the current state of this constraint set
/// Populate constraints for lhs/rhs (until the stop condition is met). Then,
/// try to prove that, based on the current state of this constraint set
/// (i.e., without analyzing additional IR or adding new constraints), the
/// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
///
Expand All @@ -220,24 +221,37 @@ class ValueBoundsConstraintSet
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
Value rhs, std::optional<int64_t> rhsDim);
bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
ComparisonOperator cmp, OpFoldResult rhs,
std::optional<int64_t> rhsDim);

/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
/// specified relation could not be proven. This could be because the
/// specified relation does in fact not hold or because there is not enough
/// information in the constraint set. In other words, if we do not know for
/// sure, this function returns "false".
///
/// This function keeps traversing the backward slice of lhs/rhs until could
/// prove the relation or until it ran out of IR.
static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like we need a helper class of sorts here that groups the results/affinemap and dims/operands together. Just makes it more difficult to mix up + gives a bit more of a consistent lhs cmp rhs interface. The packing and unpacking should folded away during compilation. [just thinking out loud when reading this, not saying required change for this]

Copy link
Member Author

@matthias-springer matthias-springer Apr 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. We could add something like:

class Variable {
  Variable(OpFoldResult);  // asserts that Value/Attribute is index-typed
  Variable(Value);  // asserts that Value is index-typed
  Variable(Value, int64_t);  // asserts that Value is a shaped value
  Variable(OpFoldResult, std::optional<int64_t>);  // must be index-typed+nullopt or shaped value+non-nullopt
  Variable(AffineMap, mapOperands);

 private:
  OpFoldResult ofr;
  std::optional<int64_t> dim;
  AffineMap map;
  SmallVector<Variable> mapOperands;
};

I think then we could have a single entry point into each of computeBound (and variants), computeConstantBound, compare and areEqual. No more overloads needed. (It requires a bit of "flattening" if we allow Variable as mapOperands.)

I'm going to prepare a follow-up PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New API added in #87980.

ComparisonOperator cmp, OpFoldResult rhs,
std::optional<int64_t> rhsDim);
static bool compare(AffineMap lhs, ValueDimList lhsOperands,
ComparisonOperator cmp, AffineMap rhs,
ValueDimList rhsOperands);
static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
ComparisonOperator cmp, AffineMap rhs,
ArrayRef<Value> rhsOperands);

/// Compute whether the given values/dimensions are equal. Return "failure" if
/// equality could not be determined.
///
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
/// index-typed.
static FailureOr<bool> areEqual(Value value1, Value value2,
static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);

/// Compute whether the given values/attributes are equal. Return "failure" if
/// equality could not be determined.
///
/// `ofr1`/`ofr2` must be of index type.
static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);

/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
/// Return "failure" if unknown.
Expand Down Expand Up @@ -294,6 +308,20 @@ class ValueBoundsConstraintSet

ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);

/// Return "true" if, based on the current state of the constraint system,
/// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
/// could not be proven. This could be because the specified relation does in
/// fact not hold or because there is not enough information in the constraint
/// set. In other words, if we do not know for sure, this function returns
/// "false".
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
ComparisonOperator cmp, OpFoldResult rhs,
std::optional<int64_t> rhsDim);
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);

/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
Expand All @@ -319,6 +347,10 @@ class ValueBoundsConstraintSet
/// set.
AffineExpr getPosExpr(int64_t pos);

/// Return "true" if the given value/dim is mapped (i.e., has a corresponding
/// column in the constraint system).
bool isMapped(Value value, std::optional<int64_t> dim = std::nullopt) const;

/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
/// "false", a dimension is added. The value/dimension is added to the
/// worklist if `addToWorklist` is set.
Expand All @@ -338,6 +370,11 @@ class ValueBoundsConstraintSet
/// dimensions but not for symbols.
int64_t insert(bool isSymbol = true);

/// Insert the given affine map and its bound operands as a new column in the
/// constraint system. Return the position of the new column. Any operands
/// that were not analyzed yet are put on the worklist.
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);

/// Project out the given column in the constraint set.
void projectOut(int64_t pos);

Expand Down
31 changes: 9 additions & 22 deletions mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,11 @@ struct ForOpInterface
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
Value initArg = forOp.getInitArgs()[iterArgIdx];

// Populate constraints for the yielded value.
cstr.populateConstraints(yieldedValue, dim);
// Populate constraints for the iter_arg. This is just to ensure that the
// iter_arg is mapped in the constraint set, which is a prerequisite for
// `compare`. It may lead to a recursive call to this function in case the
// iter_arg was not visited when the constraints for the yielded value were
// populated, but no additional work is done.
cstr.populateConstraints(iterArg, dim);

// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
if (cstr.compare(yieldedValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
dim)) {
if (cstr.populateAndCompare(
yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
iterArg, dim)) {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
Expand Down Expand Up @@ -113,10 +104,6 @@ struct IfOpInterface
Value thenValue = ifOp.thenYield().getResults()[resultNum];
Value elseValue = ifOp.elseYield().getResults()[resultNum];

// Populate constraints for the yielded value (and all values on the
// backward slice, as long as the current stop condition is not satisfied).
cstr.populateConstraints(thenValue, dim);
cstr.populateConstraints(elseValue, dim);
auto boundsBuilder = cstr.bound(value);
if (dim)
boundsBuilder[*dim];
Expand All @@ -125,9 +112,9 @@ struct IfOpInterface
// If thenValue <= elseValue:
// * result <= elseValue
// * result >= thenValue
if (cstr.compare(thenValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::LE,
elseValue, dim)) {
if (cstr.populateAndCompare(
thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
elseValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
Expand All @@ -139,9 +126,9 @@ struct IfOpInterface
// If elseValue <= thenValue:
// * result <= thenValue
// * result >= elseValue
if (cstr.compare(elseValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::LE,
thenValue, dim)) {
if (cstr.populateAndCompare(
elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
thenValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
Expand Down
Loading