diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 8e9d1021f93e4..72c5aaa230678 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -12,7 +12,6 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" using namespace mlir; -using presburger::BoundType; namespace mlir { namespace scf { @@ -21,7 +20,28 @@ namespace { struct ForOpInterface : public ValueBoundsOpInterface::ExternalModel { - /// Populate bounds of values/dimensions for iter_args/OpResults. + /// Populate bounds of values/dimensions for iter_args/OpResults. If the + /// value/dimension size does not change in an iteration, we can deduce that + /// it the same as the initial value/dimension. + /// + /// Example 1: + /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor { + /// ... + /// %1 = tensor.insert %f into %arg0[...] : tensor + /// scf.yield %1 : tensor + /// } + /// --> bound(%0)[0] == bound(%t)[0] + /// --> bound(%arg0)[0] == bound(%t)[0] + /// + /// Example 2: + /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor { + /// %sz = tensor.dim %arg0 : tensor + /// %incr = arith.addi %sz, %c1 : index + /// %1 = tensor.empty(%incr) : tensor + /// scf.yield %1 : tensor + /// } + /// --> The yielded tensor dimension size changes with each iteration. Such + /// loops are not supported and no constraints are added. static void populateIterArgBounds(scf::ForOp forOp, Value value, std::optional dim, ValueBoundsConstraintSet &cstr) { @@ -33,59 +53,31 @@ struct ForOpInterface iterArgIdx = llvm::cast(value).getResultNumber(); } - // An EQ constraint can be added if the yielded value (dimension size) - // equals the corresponding block argument (dimension size). Value yieldedValue = cast(forOp.getBody()->getTerminator()) .getOperand(iterArgIdx); Value iterArg = forOp.getRegionIterArg(iterArgIdx); Value initArg = forOp.getInitArgs()[iterArgIdx]; - auto addEqBound = [&]() { + // Populate constraints for the yielded value. + cstr.populateConstraints(yieldedValue, dim); + // Populate constraints for the iter_arg. This is just to ensure that the + // iter_arg is mapped in the constraint set, which is a prerequisite for + // `compare`. It may lead to a recursive call to this function in case the + // iter_arg was not visited when the constraints for the yielded value were + // populated, but no additional work is done. + cstr.populateConstraints(iterArg, dim); + + // An EQ constraint can be added if the yielded value (dimension size) + // equals the corresponding block argument (dimension size). + if (cstr.compare(yieldedValue, dim, + ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg, + dim)) { if (dim.has_value()) { cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); } else { cstr.bound(value) == initArg; } - }; - - if (yieldedValue == iterArg) { - addEqBound(); - return; - } - - // Compute EQ bound for yielded value. - AffineMap bound; - ValueDimList boundOperands; - LogicalResult status = ValueBoundsConstraintSet::computeBound( - bound, boundOperands, BoundType::EQ, yieldedValue, dim, - [&](Value v, std::optional d, ValueBoundsConstraintSet &cstr) { - // Stop when reaching a block argument of the loop body. - if (auto bbArg = llvm::dyn_cast(v)) - return bbArg.getOwner()->getParentOp() == forOp; - // Stop when reaching a value that is defined outside of the loop. It - // is impossible to reach an iter_arg from there. - Operation *op = v.getDefiningOp(); - return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr; - }); - if (failed(status)) - return; - if (bound.getNumResults() != 1) - return; - - // Check if computed bound equals the corresponding iter_arg. - Value singleValue = nullptr; - std::optional singleDim; - if (auto dimExpr = dyn_cast(bound.getResult(0))) { - int64_t idx = dimExpr.getPosition(); - singleValue = boundOperands[idx].first; - singleDim = boundOperands[idx].second; - } else if (auto symExpr = dyn_cast(bound.getResult(0))) { - int64_t idx = symExpr.getPosition() + bound.getNumDims(); - singleValue = boundOperands[idx].first; - singleDim = boundOperands[idx].second; } - if (singleValue == iterArg && singleDim == dim) - addEqBound(); } void populateBoundsForIndexValue(Operation *op, Value value,