diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index fbd236b648cb8..8a27bf186d1c2 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -70,6 +70,22 @@ struct ForOpInterface cstr.bound(value) == cstr.getExpr(initArg); } } + + if (dim.has_value() || isa(value)) + return; + + // `value` is result of `forOp`, we can prove that: + // %result == %init_arg + trip_count * (%yielded_value - %iter_arg). + // Where trip_count is (ub - lb) / step. + AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound()); + AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound()); + AffineExpr stepExpr = cstr.getExpr(forOp.getStep()); + AffineExpr tripCountExpr = + AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step + AffineExpr oneIterAdvanceExpr = + cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg); + cstr.bound(value) == + cstr.getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr); } void populateBoundsForIndexValue(Operation *op, Value value, diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir index 6e0c16a9a2b33..b48f38f592dc9 100644 --- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir @@ -267,3 +267,74 @@ func.func @compare_scf_for(%a: index, %b: index, %c: index) { } return } + +// ----- + +func.func @scf_for_result_infer() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %c0) -> index { + %2 = "test.some_use"() : () -> (i1) + %3 = scf.if %2 -> (index) { + %5 = arith.addi %arg, %c1 : index + scf.yield %5 : index + } else { + scf.yield %arg : index + } + scf.yield %3 : index + } + // expected-remark @below{{true}} + "test.compare"(%0, %c10) {cmp = "LE"} : (index, index) -> () + return +} + +// ----- + +func.func @scf_for_result_infer_dynamic_init(%i : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %i) -> index { + %2 = "test.some_use"() : () -> (i1) + %3 = scf.if %2 -> (index) { + %5 = arith.addi %arg, %c1 : index + scf.yield %5 : index + } else { + scf.yield %arg : index + } + scf.yield %3 : index + } + %6 = arith.addi %i, %c10 : index + // expected-remark @below{{true}} + "test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> () + return +} + +// ----- + +func.func @scf_for_result_infer_dynamic_init_big_step(%i : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c10 = arith.constant 10 : index + %0 = scf.for %iv = %c0 to %c10 step %c2 iter_args(%arg = %i) -> index { + %2 = "test.some_use"() : () -> (i1) + %3 = scf.if %2 -> (index) { + %5 = arith.addi %arg, %c1 : index + scf.yield %5 : index + } else { + scf.yield %arg : index + } + scf.yield %3 : index + } + %6 = arith.addi %i, %c5 : index + %7 = arith.addi %i, %c4 : index + // expected-remark @below{{true}} + "test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> () + // expected-error @below{{unknown}} + "test.compare"(%0, %7) {cmp = "LE"} : (index, index) -> () + return +}