diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 3543ab52407a3..1d7bc6ea961cc 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -211,7 +211,8 @@ class ValueBoundsConstraintSet /// Comparison operator for `ValueBoundsConstraintSet::compare`. enum ComparisonOperator { LT, LE, EQ, GT, GE }; - /// Try to prove that, based on the current state of this constraint set + /// Populate constraints for lhs/rhs (until the stop condition is met). Then, + /// try to prove that, based on the current state of this constraint set /// (i.e., without analyzing additional IR or adding new constraints), the /// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim. /// @@ -220,24 +221,37 @@ class ValueBoundsConstraintSet /// proven. This could be because the specified relation does in fact not hold /// or because there is not enough information in the constraint set. In other /// words, if we do not know for sure, this function returns "false". - bool compare(Value lhs, std::optional lhsDim, ComparisonOperator cmp, - Value rhs, std::optional rhsDim); + bool populateAndCompare(OpFoldResult lhs, std::optional lhsDim, + ComparisonOperator cmp, OpFoldResult rhs, + std::optional rhsDim); + + /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the + /// specified relation could not be proven. This could be because the + /// specified relation does in fact not hold or because there is not enough + /// information in the constraint set. In other words, if we do not know for + /// sure, this function returns "false". + /// + /// This function keeps traversing the backward slice of lhs/rhs until could + /// prove the relation or until it ran out of IR. + static bool compare(OpFoldResult lhs, std::optional lhsDim, + ComparisonOperator cmp, OpFoldResult rhs, + std::optional rhsDim); + static bool compare(AffineMap lhs, ValueDimList lhsOperands, + ComparisonOperator cmp, AffineMap rhs, + ValueDimList rhsOperands); + static bool compare(AffineMap lhs, ArrayRef lhsOperands, + ComparisonOperator cmp, AffineMap rhs, + ArrayRef rhsOperands); /// Compute whether the given values/dimensions are equal. Return "failure" if /// equality could not be determined. /// /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are /// index-typed. - static FailureOr areEqual(Value value1, Value value2, + static FailureOr areEqual(OpFoldResult value1, OpFoldResult value2, std::optional dim1 = std::nullopt, std::optional dim2 = std::nullopt); - /// Compute whether the given values/attributes are equal. Return "failure" if - /// equality could not be determined. - /// - /// `ofr1`/`ofr2` must be of index type. - static FailureOr areEqual(OpFoldResult ofr1, OpFoldResult ofr2); - /// Return "true" if the given slices are guaranteed to be overlapping. /// Return "false" if the given slices are guaranteed to be non-overlapping. /// Return "failure" if unknown. @@ -294,6 +308,20 @@ class ValueBoundsConstraintSet ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition); + /// Return "true" if, based on the current state of the constraint system, + /// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation + /// could not be proven. This could be because the specified relation does in + /// fact not hold or because there is not enough information in the constraint + /// set. In other words, if we do not know for sure, this function returns + /// "false". + /// + /// This function does not analyze any IR and does not populate any additional + /// constraints. + bool compareValueDims(OpFoldResult lhs, std::optional lhsDim, + ComparisonOperator cmp, OpFoldResult rhs, + std::optional rhsDim); + bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos); + /// Given an affine map with a single result (and map operands), add a new /// column to the constraint set that represents the result of the map. /// Traverse additional IR starting from the map operands as needed (as long @@ -319,6 +347,10 @@ class ValueBoundsConstraintSet /// set. AffineExpr getPosExpr(int64_t pos); + /// Return "true" if the given value/dim is mapped (i.e., has a corresponding + /// column in the constraint system). + bool isMapped(Value value, std::optional dim = std::nullopt) const; + /// Insert a value/dimension into the constraint set. If `isSymbol` is set to /// "false", a dimension is added. The value/dimension is added to the /// worklist if `addToWorklist` is set. @@ -338,6 +370,11 @@ class ValueBoundsConstraintSet /// dimensions but not for symbols. int64_t insert(bool isSymbol = true); + /// Insert the given affine map and its bound operands as a new column in the + /// constraint system. Return the position of the new column. Any operands + /// that were not analyzed yet are put on the worklist. + int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true); + /// Project out the given column in the constraint set. void projectOut(int64_t pos); diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 72c5aaa230678..087ffc438a830 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -58,20 +58,11 @@ struct ForOpInterface Value iterArg = forOp.getRegionIterArg(iterArgIdx); Value initArg = forOp.getInitArgs()[iterArgIdx]; - // Populate constraints for the yielded value. - cstr.populateConstraints(yieldedValue, dim); - // Populate constraints for the iter_arg. This is just to ensure that the - // iter_arg is mapped in the constraint set, which is a prerequisite for - // `compare`. It may lead to a recursive call to this function in case the - // iter_arg was not visited when the constraints for the yielded value were - // populated, but no additional work is done. - cstr.populateConstraints(iterArg, dim); - // An EQ constraint can be added if the yielded value (dimension size) // equals the corresponding block argument (dimension size). - if (cstr.compare(yieldedValue, dim, - ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg, - dim)) { + if (cstr.populateAndCompare( + yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ, + iterArg, dim)) { if (dim.has_value()) { cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); } else { @@ -113,10 +104,6 @@ struct IfOpInterface Value thenValue = ifOp.thenYield().getResults()[resultNum]; Value elseValue = ifOp.elseYield().getResults()[resultNum]; - // Populate constraints for the yielded value (and all values on the - // backward slice, as long as the current stop condition is not satisfied). - cstr.populateConstraints(thenValue, dim); - cstr.populateConstraints(elseValue, dim); auto boundsBuilder = cstr.bound(value); if (dim) boundsBuilder[*dim]; @@ -125,9 +112,9 @@ struct IfOpInterface // If thenValue <= elseValue: // * result <= elseValue // * result >= thenValue - if (cstr.compare(thenValue, dim, - ValueBoundsConstraintSet::ComparisonOperator::LE, - elseValue, dim)) { + if (cstr.populateAndCompare( + thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, + elseValue, dim)) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim); @@ -139,9 +126,9 @@ struct IfOpInterface // If elseValue <= thenValue: // * result <= thenValue // * result >= elseValue - if (cstr.compare(elseValue, dim, - ValueBoundsConstraintSet::ComparisonOperator::LE, - thenValue, dim)) { + if (cstr.populateAndCompare( + elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, + thenValue, dim)) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim); diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 6e3d6dd3c7575..c138056ab41cc 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -202,6 +202,28 @@ int64_t ValueBoundsConstraintSet::insert(bool isSymbol) { return pos; } +int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands, + bool isSymbol) { + assert(map.getNumResults() == 1 && "expected affine map with one result"); + int64_t pos = insert(/*isSymbol=*/false); + + // Add map and operands to the constraint set. Dimensions are converted to + // symbols. All operands are added to the worklist (unless they were already + // processed). + auto mapper = [&](std::pair> v) { + return getExpr(v.first, v.second); + }; + SmallVector dimReplacements = llvm::to_vector( + llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper)); + SmallVector symReplacements = llvm::to_vector( + llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper)); + addBound( + presburger::BoundType::EQ, pos, + map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements)); + + return pos; +} + int64_t ValueBoundsConstraintSet::getPos(Value value, std::optional dim) const { #ifndef NDEBUG @@ -224,6 +246,13 @@ AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) { : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); } +bool ValueBoundsConstraintSet::isMapped(Value value, + std::optional dim) const { + auto it = + valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue))); + return it != valueDimToPosition.end(); +} + static Operation *getOwnerOfValue(Value value) { if (auto bbArg = dyn_cast(value)) return bbArg.getOwner()->getParentOp(); @@ -560,27 +589,10 @@ void ValueBoundsConstraintSet::populateConstraints(Value value, int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map, ValueDimList operands) { - assert(map.getNumResults() == 1 && "expected affine map with one result"); - int64_t pos = insert(/*isSymbol=*/false); - - // Add map and operands to the constraint set. Dimensions are converted to - // symbols. All operands are added to the worklist (unless they were already - // processed). - auto mapper = [&](std::pair> v) { - return getExpr(v.first, v.second); - }; - SmallVector dimReplacements = llvm::to_vector( - llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper)); - SmallVector symReplacements = llvm::to_vector( - llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper)); - addBound( - presburger::BoundType::EQ, pos, - map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements)); - + int64_t pos = insert(map, operands, /*isSymbol=*/false); // Process the backward slice of `operands` (i.e., reverse use-def chain) // until `stopCondition` is met. processWorklist(); - return pos; } @@ -600,9 +612,18 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, {{value1, dim1}, {value2, dim2}}); } -bool ValueBoundsConstraintSet::compare(Value lhs, std::optional lhsDim, - ComparisonOperator cmp, Value rhs, - std::optional rhsDim) { +bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs, + std::optional lhsDim, + ComparisonOperator cmp, + OpFoldResult rhs, + std::optional rhsDim) { +#ifndef NDEBUG + if (auto lhsVal = dyn_cast(lhs)) + assertValidValueDim(lhsVal, lhsDim); + if (auto rhsVal = dyn_cast(rhs)) + assertValidValueDim(rhsVal, rhsDim); +#endif // NDEBUG + // This function returns "true" if "lhs CMP rhs" is proven to hold. // // Example for ComparisonOperator::LE and index-typed values: We would like to @@ -621,24 +642,32 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional lhsDim, // EQ can be expressed as LE and GE. if (cmp == EQ) - return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) && - compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim); + return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) && + compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim); // Construct inequality. For the above example: lhs > rhs. // `IntegerRelation` inequalities are expressed in the "flattened" form and // with ">= 0". I.e., lhs - rhs - 1 >= 0. - SmallVector eq(cstr.getNumDimAndSymbolVars() + 1, 0); + SmallVector eq(cstr.getNumCols(), 0); + auto addToEq = [&](OpFoldResult ofr, std::optional dim, + int64_t factor) { + if (auto constVal = ::getConstantIntValue(ofr)) { + eq[cstr.getNumCols() - 1] += *constVal * factor; + } else { + eq[getPos(cast(ofr), dim)] += factor; + } + }; if (cmp == LT || cmp == LE) { - ++eq[getPos(lhs, lhsDim)]; - --eq[getPos(rhs, rhsDim)]; + addToEq(lhs, lhsDim, 1); + addToEq(rhs, rhsDim, -1); } else if (cmp == GT || cmp == GE) { - --eq[getPos(lhs, lhsDim)]; - ++eq[getPos(rhs, rhsDim)]; + addToEq(lhs, lhsDim, -1); + addToEq(rhs, rhsDim, 1); } else { llvm_unreachable("unsupported comparison operator"); } if (cmp == LE || cmp == GE) - eq[cstr.getNumDimAndSymbolVars()] -= 1; + eq[cstr.getNumCols() - 1] -= 1; // Add inequality to the constraint set and check if it made the constraint // set empty. @@ -649,40 +678,128 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional lhsDim, return isEmpty; } +bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, + ComparisonOperator cmp, + int64_t rhsPos) { + // This function returns "true" if "lhs CMP rhs" is proven to hold. For + // detailed documentation, see `compareValueDims`. + + // EQ can be expressed as LE and GE. + if (cmp == EQ) + return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) && + comparePos(lhsPos, ComparisonOperator::GE, rhsPos); + + // Construct inequality. + SmallVector eq(cstr.getNumCols(), 0); + if (cmp == LT || cmp == LE) { + ++eq[lhsPos]; + --eq[rhsPos]; + } else if (cmp == GT || cmp == GE) { + --eq[lhsPos]; + ++eq[rhsPos]; + } else { + llvm_unreachable("unsupported comparison operator"); + } + if (cmp == LE || cmp == GE) + eq[cstr.getNumCols() - 1] -= 1; + + // Add inequality to the constraint set and check if it made the constraint + // set empty. + int64_t ineqPos = cstr.getNumInequalities(); + cstr.addInequality(eq); + bool isEmpty = cstr.isEmpty(); + cstr.removeInequality(ineqPos); + return isEmpty; +} + +bool ValueBoundsConstraintSet::populateAndCompare( + OpFoldResult lhs, std::optional lhsDim, ComparisonOperator cmp, + OpFoldResult rhs, std::optional rhsDim) { +#ifndef NDEBUG + if (auto lhsVal = dyn_cast(lhs)) + assertValidValueDim(lhsVal, lhsDim); + if (auto rhsVal = dyn_cast(rhs)) + assertValidValueDim(rhsVal, rhsDim); +#endif // NDEBUG + + if (auto lhsVal = dyn_cast(lhs)) + populateConstraints(lhsVal, lhsDim); + if (auto rhsVal = dyn_cast(rhs)) + populateConstraints(rhsVal, rhsDim); + + return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim); +} + +bool ValueBoundsConstraintSet::compare(OpFoldResult lhs, + std::optional lhsDim, + ComparisonOperator cmp, OpFoldResult rhs, + std::optional rhsDim) { + auto stopCondition = [&](Value v, std::optional dim, + ValueBoundsConstraintSet &cstr) { + // Keep processing as long as lhs/rhs are not mapped. + if (auto lhsVal = dyn_cast(lhs)) + if (!cstr.isMapped(lhsVal, dim)) + return false; + if (auto rhsVal = dyn_cast(rhs)) + if (!cstr.isMapped(rhsVal, dim)) + return false; + // Keep processing as long as the relation cannot be proven. + return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim); + }; + + ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); + return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim); +} + +bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands, + ComparisonOperator cmp, AffineMap rhs, + ValueDimList rhsOperands) { + int64_t lhsPos = -1, rhsPos = -1; + auto stopCondition = [&](Value v, std::optional dim, + ValueBoundsConstraintSet &cstr) { + // Keep processing as long as lhs/rhs were not processed. + if (lhsPos >= cstr.positionToValueDim.size() || + rhsPos >= cstr.positionToValueDim.size()) + return false; + // Keep processing as long as the relation cannot be proven. + return cstr.comparePos(lhsPos, cmp, rhsPos); + }; + ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); + lhsPos = cstr.insert(lhs, lhsOperands); + rhsPos = cstr.insert(rhs, rhsOperands); + cstr.processWorklist(); + return cstr.comparePos(lhsPos, cmp, rhsPos); +} + +bool ValueBoundsConstraintSet::compare(AffineMap lhs, + ArrayRef lhsOperands, + ComparisonOperator cmp, AffineMap rhs, + ArrayRef rhsOperands) { + ValueDimList lhsValueDimOperands = + llvm::map_to_vector(lhsOperands, [](Value v) { + return std::make_pair(v, std::optional()); + }); + ValueDimList rhsValueDimOperands = + llvm::map_to_vector(rhsOperands, [](Value v) { + return std::make_pair(v, std::optional()); + }); + return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs, + rhsValueDimOperands); +} + FailureOr -ValueBoundsConstraintSet::areEqual(Value value1, Value value2, +ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2, std::optional dim1, std::optional dim2) { - // Subtract the two values/dimensions from each other. If the result is 0, - // both are equal. - FailureOr delta = computeConstantDelta(value1, value2, dim1, dim2); - if (failed(delta)) - return failure(); - return *delta == 0; -} - -FailureOr ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1, - OpFoldResult ofr2) { - Builder b(ofr1.getContext()); - AffineMap map = - AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2, - b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1)); - SmallVector ofrOperands; - ofrOperands.push_back(ofr1); - ofrOperands.push_back(ofr2); - SmallVector valueOperands; - AffineMap foldedMap = - foldAttributesIntoMap(b, map, ofrOperands, valueOperands); - ValueDimList valueDims; - for (Value v : valueOperands) { - assert(v.getType().isIndex() && "expected index type"); - valueDims.emplace_back(v, std::nullopt); - } - FailureOr delta = - computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims); - if (failed(delta)) - return failure(); - return *delta == 0; + if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ, + value2, dim2)) + return true; + if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT, + value2, dim2) || + ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT, + value2, dim2)) + return false; + return failure(); } FailureOr diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir index 55282e8334abd..10da91870f49d 100644 --- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -79,6 +79,17 @@ func.func @composed_affine_apply(%i1 : index) -> (index) { } +// ----- + +func.func @are_equal(%i1 : index) { + %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1) + %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1) + %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3] + // expected-remark @below{{false}} + "test.compare"(%i2, %i3) : (index, index) -> () + return +} + // ----- // Test for affine::fullyComposeAndCheckIfEqual @@ -87,6 +98,36 @@ func.func @composed_are_equal(%i1 : index) { %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1) %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3] // expected-remark @below{{different}} - "test.are_equal"(%i2, %i3) {compose} : (index, index) -> () + "test.compare"(%i2, %i3) {compose} : (index, index) -> () + return +} + +// ----- + +func.func @compare_affine_max(%a: index, %b: index) { + %0 = affine.max affine_map<()[s0, s1] -> (s0, s1)>()[%a, %b] + // expected-remark @below{{true}} + "test.compare"(%0, %a) {cmp = "GE"} : (index, index) -> () + // expected-error @below{{unknown}} + "test.compare"(%0, %a) {cmp = "GT"} : (index, index) -> () + // expected-remark @below{{false}} + "test.compare"(%0, %a) {cmp = "LT"} : (index, index) -> () + // expected-error @below{{unknown}} + "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> () + return +} + +// ----- + +func.func @compare_affine_min(%a: index, %b: index) { + %0 = affine.min affine_map<()[s0, s1] -> (s0, s1)>()[%a, %b] + // expected-error @below{{unknown}} + "test.compare"(%0, %a) {cmp = "GE"} : (index, index) -> () + // expected-remark @below{{false}} + "test.compare"(%0, %a) {cmp = "GT"} : (index, index) -> () + // expected-error @below{{unknown}} + "test.compare"(%0, %a) {cmp = "LT"} : (index, index) -> () + // expected-remark @below{{true}} + "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> () return } diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir index 0ea06737886d4..9ab03da1c9a94 100644 --- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir @@ -219,3 +219,15 @@ func.func @scf_if_eq(%a: index, %b: index, %c : i1) { "test.some_use"(%reify1) : (index) -> () return } + +// ----- + +func.func @compare_scf_for(%a: index, %b: index, %c: index) { + scf.for %iv = %a to %b step %c { + // expected-remark @below{{true}} + "test.compare"(%iv, %a) {cmp = "GE"} : (index, index) -> () + // expected-remark @below{{true}} + "test.compare"(%iv, %b) {cmp = "LT"} : (index, index) -> () + } + return +} diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir index 45520da6aeb0b..0c90bcdb42028 100644 --- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir @@ -163,8 +163,8 @@ func.func @dynamic_dims_are_equal(%t: tensor) { %c0 = arith.constant 0 : index %dim0 = tensor.dim %t, %c0 : tensor %dim1 = tensor.dim %t, %c0 : tensor - // expected-remark @below {{equal}} - "test.are_equal"(%dim0, %dim1) : (index, index) -> () + // expected-remark @below {{true}} + "test.compare"(%dim0, %dim1) : (index, index) -> () return } @@ -175,8 +175,8 @@ func.func @dynamic_dims_are_different(%t: tensor) { %c1 = arith.constant 1 : index %dim0 = tensor.dim %t, %c0 : tensor %val = arith.addi %dim0, %c1 : index - // expected-remark @below {{different}} - "test.are_equal"(%dim0, %val) : (index, index) -> () + // expected-remark @below {{false}} + "test.compare"(%dim0, %val) : (index, index) -> () return } @@ -186,8 +186,8 @@ func.func @dynamic_dims_are_maybe_equal_1(%t: tensor) { %c0 = arith.constant 0 : index %c5 = arith.constant 5 : index %dim0 = tensor.dim %t, %c0 : tensor - // expected-error @below {{could not determine equality}} - "test.are_equal"(%dim0, %c5) : (index, index) -> () + // expected-error @below {{unknown}} + "test.compare"(%dim0, %c5) : (index, index) -> () return } @@ -198,7 +198,7 @@ func.func @dynamic_dims_are_maybe_equal_2(%t: tensor) { %c1 = arith.constant 1 : index %dim0 = tensor.dim %t, %c0 : tensor %dim1 = tensor.dim %t, %c1 : tensor - // expected-error @below {{could not determine equality}} - "test.are_equal"(%dim0, %dim1) : (index, index) -> () + // expected-error @below {{unknown}} + "test.compare"(%dim0, %dim1) : (index, index) -> () return } diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 4b2b1a06341b7..f38631054fb3c 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -57,7 +57,7 @@ struct TestReifyValueBounds } // namespace -FailureOr parseBoundType(const std::string &type) { +static FailureOr parseBoundType(const std::string &type) { if (type == "EQ") return BoundType::EQ; if (type == "LB") @@ -67,6 +67,34 @@ FailureOr parseBoundType(const std::string &type) { return failure(); } +static FailureOr +parseComparisonOperator(const std::string &type) { + if (type == "EQ") + return ValueBoundsConstraintSet::ComparisonOperator::EQ; + if (type == "LT") + return ValueBoundsConstraintSet::ComparisonOperator::LT; + if (type == "LE") + return ValueBoundsConstraintSet::ComparisonOperator::LE; + if (type == "GT") + return ValueBoundsConstraintSet::ComparisonOperator::GT; + if (type == "GE") + return ValueBoundsConstraintSet::ComparisonOperator::GE; + return failure(); +} + +static ValueBoundsConstraintSet::ComparisonOperator +invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) { + if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT) + return ValueBoundsConstraintSet::ComparisonOperator::GE; + if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE) + return ValueBoundsConstraintSet::ComparisonOperator::GT; + if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT) + return ValueBoundsConstraintSet::ComparisonOperator::LE; + if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE) + return ValueBoundsConstraintSet::ComparisonOperator::LT; + llvm_unreachable("unsupported comparison operator"); +} + /// Look for "test.reify_bound" ops in the input and replace their results with /// the reified values. static LogicalResult testReifyValueBounds(func::FuncOp funcOp, @@ -215,18 +243,34 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, return failure(result.wasInterrupted()); } -/// Look for "test.are_equal" ops and emit errors/remarks. +/// Look for "test.compare" ops and emit errors/remarks. static LogicalResult testEquality(func::FuncOp funcOp) { IRRewriter rewriter(funcOp.getContext()); WalkResult result = funcOp.walk([&](Operation *op) { - // Look for test.are_equal ops. - if (op->getName().getStringRef() == "test.are_equal") { + // Look for test.compare ops. + if (op->getName().getStringRef() == "test.compare") { if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() || !op->getOperand(1).getType().isIndex()) { op->emitOpError("invalid op"); return WalkResult::skip(); } + + // Get comparison operator. + std::string cmpStr = "EQ"; + if (auto cmpAttr = op->getAttrOfType("cmp")) + cmpStr = cmpAttr.str(); + auto cmpType = parseComparisonOperator(cmpStr); + if (failed(cmpType)) { + op->emitOpError("invalid comparison operator"); + return WalkResult::interrupt(); + } + if (op->hasAttr("compose")) { + if (cmpType != ValueBoundsConstraintSet::EQ) { + op->emitOpError( + "comparison operator must be EQ when 'composed' is specified"); + return WalkResult::interrupt(); + } FailureOr delta = affine::fullyComposeAndComputeConstantDelta( op->getOperand(0), op->getOperand(1)); if (failed(delta)) { @@ -236,16 +280,25 @@ static LogicalResult testEquality(func::FuncOp funcOp) { } else { op->emitRemark("different"); } + return WalkResult::advance(); + } + + auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) { + return ValueBoundsConstraintSet::compare( + /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp, + /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt); + }; + if (compare(*cmpType)) { + op->emitRemark("true"); + } else if (*cmpType != ValueBoundsConstraintSet::EQ && + compare(invertComparisonOperator(*cmpType))) { + op->emitRemark("false"); + } else if (*cmpType == ValueBoundsConstraintSet::EQ && + (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) || + compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) { + op->emitRemark("false"); } else { - FailureOr equal = ValueBoundsConstraintSet::areEqual( - op->getOperand(0), op->getOperand(1)); - if (failed(equal)) { - op->emitError("could not determine equality"); - } else if (*equal) { - op->emitRemark("equal"); - } else { - op->emitRemark("different"); - } + op->emitError("unknown"); } } return WalkResult::advance();