diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 5bca8e85f889d..7a1aafc9f1c2f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3884,6 +3884,95 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern { return success(); } }; + +/// If both ranges contain same values return mappping indices from args2 to +/// args1. Otherwise return std::nullopt. +static std::optional> getArgsMapping(ValueRange args1, + ValueRange args2) { + if (args1.size() != args2.size()) + return std::nullopt; + + SmallVector ret(args1.size()); + for (auto &&[i, arg1] : llvm::enumerate(args1)) { + auto it = llvm::find(args2, arg1); + if (it == args2.end()) + return std::nullopt; + + ret[std::distance(args2.begin(), it)] = static_cast(i); + } + + return ret; +} + +static bool hasDuplicates(ValueRange args) { + llvm::SmallDenseSet set; + for (Value arg : args) { + if (set.contains(arg)) + return true; + + set.insert(arg); + } + return false; +} + +/// If `before` block args are directly forwarded to `scf.condition`, rearrange +/// `scf.condition` args into same order as block args. Update `after` block +/// args and op result values accordingly. +/// Needed to simplify `scf.while` -> `scf.for` uplifting. +struct WhileOpAlignBeforeArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp loop, + PatternRewriter &rewriter) const override { + auto oldBefore = loop.getBeforeBody(); + ConditionOp oldTerm = loop.getConditionOp(); + ValueRange beforeArgs = oldBefore->getArguments(); + ValueRange termArgs = oldTerm.getArgs(); + if (beforeArgs == termArgs) + return failure(); + + if (hasDuplicates(termArgs)) + return failure(); + + auto mapping = getArgsMapping(beforeArgs, termArgs); + if (!mapping) + return failure(); + + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(oldTerm); + rewriter.replaceOpWithNewOp(oldTerm, oldTerm.getCondition(), + beforeArgs); + } + + auto oldAfter = loop.getAfterBody(); + + SmallVector newResultTypes(beforeArgs.size()); + for (auto &&[i, j] : llvm::enumerate(*mapping)) + newResultTypes[j] = loop.getResult(i).getType(); + + auto newLoop = rewriter.create( + loop.getLoc(), newResultTypes, loop.getInits(), + /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); + auto newBefore = newLoop.getBeforeBody(); + auto newAfter = newLoop.getAfterBody(); + + SmallVector newResults(beforeArgs.size()); + SmallVector newAfterArgs(beforeArgs.size()); + for (auto &&[i, j] : llvm::enumerate(*mapping)) { + newResults[i] = newLoop.getResult(j); + newAfterArgs[i] = newAfter->getArgument(j); + } + + rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(), + newBefore->getArguments()); + rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(), + newAfterArgs); + + rewriter.replaceOp(loop, newResults); + return success(); + } +}; } // namespace void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -3891,7 +3980,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); + WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 52e0fdfa36d6c..b4c9ed4db94e0 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 { // CHECK: return %[[RES]] : i32 +// ----- + +// CHECK-LABEL: func @test_align_args +// CHECK: %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) { +// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64 +// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64): +// CHECK: %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32 +// CHECK: %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32 +// CHECK: %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64 +// CHECK: scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64 +// CHECK: return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1 +func.func @test_align_args() -> (i64, f32, i32) { + %0 = "test.test"() : () -> (f32) + %1 = "test.test"() : () -> (i32) + %2 = "test.test"() : () -> (i64) + %3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) { + %cond = "test.test"() : () -> (i1) + scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32 + } do { + ^bb0(%arg3: i64, %arg4: f32, %arg5: i32): + %4 = "test.test"(%arg3) : (i64) -> (f32) + %5 = "test.test"(%arg4) : (f32) -> (i32) + %6 = "test.test"(%arg5) : (i32) -> (i64) + scf.yield %4, %5, %6 : f32, i32, i64 + } + return %3#0, %3#1, %3#2 : i64, f32, i32 +} + + // ----- // CHECK-LABEL: @combineIfs