Skip to content

[mlir][scf] Align scf.while before block args in canonicalizer #76195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 2, 2024

Conversation

Hardcode84
Copy link
Contributor

@Hardcode84 Hardcode84 commented Dec 21, 2023

If before block args are directly forwarded to scf.condition make sure they are passed in the same order.
This is needed for scf.while uplifting #76108

@llvmbot
Copy link
Member

llvmbot commented Dec 21, 2023

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

If before block args are directly forwarded to scf.condition make sure they are passes in the same order.
This is needed for scf.while uplifting #76108


Full diff: https://github.com/llvm/llvm-project/pull/76195.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+76-1)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+29)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5570c2ec688c8a..de320723ce83f3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3872,6 +3872,81 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
     return success();
   }
 };
+
+/// If both ranges contain same values return mappping indices from args1 to
+/// args2. Otherwise return std::nullopt
+static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
+                                                           ValueRange args2) {
+  if (args1.size() != args2.size())
+    return std::nullopt;
+
+  SmallVector<unsigned> ret(args1.size());
+  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
+    auto it = llvm::find(args2, arg1);
+    if (it == args2.end())
+      return std::nullopt;
+
+    auto j = it - args2.begin();
+    ret[j] = static_cast<unsigned>(i);
+  }
+
+  return ret;
+}
+
+/// If `before` block args are directly forwarded to `scf.condition`, rearrange
+/// `scf.condition` args into same order as block args. Update `after` block
+// args and results values accordingly.
+/// Needed to simplify `scf.while` -> `scf.for` uplifting.
+struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    auto oldBefore = loop.getBeforeBody();
+    ConditionOp oldTerm = loop.getConditionOp();
+    ValueRange beforeArgs = oldBefore->getArguments();
+    ValueRange termArgs = oldTerm.getArgs();
+    if (beforeArgs == termArgs)
+      return failure();
+
+    auto mapping = getArgsMapping(beforeArgs, termArgs);
+    if (!mapping)
+      return failure();
+
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(oldTerm);
+      rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
+                                               beforeArgs);
+    }
+
+    auto oldAfter = loop.getAfterBody();
+
+    SmallVector<Type> newResultTypes(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping))
+      newResultTypes[j] = loop.getResult(i).getType();
+
+    auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
+                                            loop.getInits(), nullptr, nullptr);
+    auto newBefore = newLoop.getBeforeBody();
+    auto newAfter = newLoop.getAfterBody();
+
+    SmallVector<Value> newResults(beforeArgs.size());
+    SmallVector<Value> newAfterArgs(beforeArgs.size());
+    for (auto &&[i, j] : llvm::enumerate(*mapping)) {
+      newResults[i] = newLoop.getResult(j);
+      newAfterArgs[i] = newAfter->getArgument(j);
+    }
+
+    rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
+                               newBefore->getArguments());
+    rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
+                               newAfterArgs);
+
+    rewriter.replaceOp(loop, newResults);
+    return success();
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -3879,7 +3954,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs>(context);
+              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 52e0fdfa36d6cd..b4c9ed4db94e0e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 {
 // CHECK:         return %[[RES]] : i32
 
 
+// -----
+
+// CHECK-LABEL: func @test_align_args
+//       CHECK:  %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
+//       CHECK:  scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
+//       CHECK:  ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
+//       CHECK:  %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
+//       CHECK:  %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
+//       CHECK:  %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
+//       CHECK:  scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
+//       CHECK:  return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
+func.func @test_align_args() -> (i64, f32, i32) {
+  %0 = "test.test"() : () -> (f32)
+  %1 = "test.test"() : () -> (i32)
+  %2 = "test.test"() : () -> (i64)
+  %3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
+    %cond = "test.test"() : () -> (i1)
+    scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
+  } do {
+  ^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
+    %4 = "test.test"(%arg3) : (i64) -> (f32)
+    %5 = "test.test"(%arg4) : (f32) -> (i32)
+    %6 = "test.test"(%arg5) : (i32) -> (i64)
+    scf.yield %4, %5, %6 : f32, i32, i64
+  }
+  return %3#0, %3#1, %3#2 : i64, f32, i32
+}
+
+
 // -----
 
 // CHECK-LABEL: @combineIfs

return std::nullopt;

auto j = it - args2.begin();
ret[j] = static_cast<unsigned>(i);
Copy link
Member

@matthias-springer matthias-springer Dec 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ret[std::distance(arg2.begin(), it)] = i

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


/// If `before` block args are directly forwarded to `scf.condition`, rearrange
/// `scf.condition` args into same order as block args. Update `after` block
// args and results values accordingly.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

result values, a / is missing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

newResultTypes[j] = loop.getResult(i).getType();

auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
loop.getInits(), nullptr, nullptr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: spell out variable names for nullptr arguments, i.e., /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better: we should be able to remove these two nullptr entirely, there is another build method that should match.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only other build method is default-generated static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); and it does not populate regions IIRC.

@@ -3872,14 +3872,89 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
return success();
}
};

/// If both ranges contain same values return mappping indices from args1 to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this actually a mapping from args2 to args1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


SmallVector<unsigned> ret(args1.size());
for (auto &&[i, arg1] : llvm::enumerate(args1)) {
auto it = llvm::find(args2, arg1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will break if there are duplicate values in args1 and args2.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have an example where we'd want to find a mapping but there are duplicates?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The value range could contain the same value multiple times. eg if the same value is yielded multiple times from the before block. I think then this function no longer returns a permutation. actually, maybe that is fine...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the args is block args so it will never contain duplicates, the other is scf.condition args. There is a pattern to cleanup duplicated scf.condition args, but order in which patterns are applied is undetermined. I can add a code to check for duplicated scf.condition args and bail out.

Copy link
Contributor Author

@Hardcode84 Hardcode84 Dec 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OR, I can set higher benefit to duplicates cleanup pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added duplicates check

@Hardcode84 Hardcode84 force-pushed the scf-align-while-args branch from 16d1144 to 24797e2 Compare March 30, 2024 13:45
If `before` block args are directly forwarded to `scf.condition` make sure they are passes in the same order.
This is needed for `scf.while` uplifting llvm#76108
@Hardcode84 Hardcode84 force-pushed the scf-align-while-args branch from b2c3e8e to 1f8b1b1 Compare April 2, 2024 11:06
@Hardcode84 Hardcode84 merged commit a6d932b into llvm:main Apr 2, 2024
@Hardcode84 Hardcode84 deleted the scf-align-while-args branch April 2, 2024 14:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants