-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir][scf] Align scf.while
before
block args in canonicalizer
#76195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesIf Full diff: https://github.com/llvm/llvm-project/pull/76195.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5570c2ec688c8a..de320723ce83f3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3872,6 +3872,81 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
return success();
}
};
+
+/// If both ranges contain same values return mappping indices from args1 to
+/// args2. Otherwise return std::nullopt
+static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
+ ValueRange args2) {
+ if (args1.size() != args2.size())
+ return std::nullopt;
+
+ SmallVector<unsigned> ret(args1.size());
+ for (auto &&[i, arg1] : llvm::enumerate(args1)) {
+ auto it = llvm::find(args2, arg1);
+ if (it == args2.end())
+ return std::nullopt;
+
+ auto j = it - args2.begin();
+ ret[j] = static_cast<unsigned>(i);
+ }
+
+ return ret;
+}
+
+/// If `before` block args are directly forwarded to `scf.condition`, rearrange
+/// `scf.condition` args into same order as block args. Update `after` block
+// args and results values accordingly.
+/// Needed to simplify `scf.while` -> `scf.for` uplifting.
+struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WhileOp loop,
+ PatternRewriter &rewriter) const override {
+ auto oldBefore = loop.getBeforeBody();
+ ConditionOp oldTerm = loop.getConditionOp();
+ ValueRange beforeArgs = oldBefore->getArguments();
+ ValueRange termArgs = oldTerm.getArgs();
+ if (beforeArgs == termArgs)
+ return failure();
+
+ auto mapping = getArgsMapping(beforeArgs, termArgs);
+ if (!mapping)
+ return failure();
+
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(oldTerm);
+ rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
+ beforeArgs);
+ }
+
+ auto oldAfter = loop.getAfterBody();
+
+ SmallVector<Type> newResultTypes(beforeArgs.size());
+ for (auto &&[i, j] : llvm::enumerate(*mapping))
+ newResultTypes[j] = loop.getResult(i).getType();
+
+ auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes,
+ loop.getInits(), nullptr, nullptr);
+ auto newBefore = newLoop.getBeforeBody();
+ auto newAfter = newLoop.getAfterBody();
+
+ SmallVector<Value> newResults(beforeArgs.size());
+ SmallVector<Value> newAfterArgs(beforeArgs.size());
+ for (auto &&[i, j] : llvm::enumerate(*mapping)) {
+ newResults[i] = newLoop.getResult(j);
+ newAfterArgs[i] = newAfter->getArgument(j);
+ }
+
+ rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
+ newBefore->getArguments());
+ rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
+ newAfterArgs);
+
+ rewriter.replaceOp(loop, newResults);
+ return success();
+ }
+};
} // namespace
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -3879,7 +3954,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs>(context);
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 52e0fdfa36d6cd..b4c9ed4db94e0e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 {
// CHECK: return %[[RES]] : i32
+// -----
+
+// CHECK-LABEL: func @test_align_args
+// CHECK: %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
+// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
+// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
+// CHECK: %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
+// CHECK: %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
+// CHECK: %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
+// CHECK: scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
+// CHECK: return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
+func.func @test_align_args() -> (i64, f32, i32) {
+ %0 = "test.test"() : () -> (f32)
+ %1 = "test.test"() : () -> (i32)
+ %2 = "test.test"() : () -> (i64)
+ %3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
+ %cond = "test.test"() : () -> (i1)
+ scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
+ } do {
+ ^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
+ %4 = "test.test"(%arg3) : (i64) -> (f32)
+ %5 = "test.test"(%arg4) : (f32) -> (i32)
+ %6 = "test.test"(%arg5) : (i32) -> (i64)
+ scf.yield %4, %5, %6 : f32, i32, i64
+ }
+ return %3#0, %3#1, %3#2 : i64, f32, i32
+}
+
+
// -----
// CHECK-LABEL: @combineIfs
|
mlir/lib/Dialect/SCF/IR/SCF.cpp
Outdated
return std::nullopt; | ||
|
||
auto j = it - args2.begin(); | ||
ret[j] = static_cast<unsigned>(i); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ret[std::distance(arg2.begin(), it)] = i
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
mlir/lib/Dialect/SCF/IR/SCF.cpp
Outdated
|
||
/// If `before` block args are directly forwarded to `scf.condition`, rearrange | ||
/// `scf.condition` args into same order as block args. Update `after` block | ||
// args and results values accordingly. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
result values
, a /
is missing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
mlir/lib/Dialect/SCF/IR/SCF.cpp
Outdated
newResultTypes[j] = loop.getResult(i).getType(); | ||
|
||
auto newLoop = rewriter.create<WhileOp>(loop.getLoc(), newResultTypes, | ||
loop.getInits(), nullptr, nullptr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: spell out variable names for nullptr
arguments, i.e., /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better: we should be able to remove these two nullptr
entirely, there is another build method that should match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only other build method is default-generated static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
and it does not populate regions IIRC.
mlir/lib/Dialect/SCF/IR/SCF.cpp
Outdated
@@ -3872,14 +3872,89 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> { | |||
return success(); | |||
} | |||
}; | |||
|
|||
/// If both ranges contain same values return mappping indices from args1 to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this actually a mapping from args2
to args1
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
|
||
SmallVector<unsigned> ret(args1.size()); | ||
for (auto &&[i, arg1] : llvm::enumerate(args1)) { | ||
auto it = llvm::find(args2, arg1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this will break if there are duplicate values in args1
and args2
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have an example where we'd want to find a mapping but there are duplicates?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The value range could contain the same value multiple times. eg if the same value is yielded multiple times from the before block. I think then this function no longer returns a permutation. actually, maybe that is fine...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of the args is block args so it will never contain duplicates, the other is scf.condition
args. There is a pattern to cleanup duplicated scf.condition
args, but order in which patterns are applied is undetermined. I can add a code to check for duplicated scf.condition
args and bail out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OR, I can set higher benefit to duplicates cleanup pattern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added duplicates check
16d1144
to
24797e2
Compare
If `before` block args are directly forwarded to `scf.condition` make sure they are passes in the same order. This is needed for `scf.while` uplifting llvm#76108
b2c3e8e
to
1f8b1b1
Compare
If
before
block args are directly forwarded toscf.condition
make sure they are passed in the same order.This is needed for
scf.while
uplifting #76108