-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][scf] Add reductions support to scf.parallel
fusion
#75955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesFull diff: https://github.com/llvm/llvm-project/pull/75955.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7..ea9fbee26fdeb0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -131,29 +131,63 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
}
/// Prepends operations of firstPloop's body into secondPloop's body.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
- OpBuilder b,
+/// Updates secondPloop with new loop.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
+ OpBuilder builder,
llvm::function_ref<bool(Value, Value)> mayAlias) {
+ Block *block1 = firstPloop.getBody();
+ Block *block2 = secondPloop.getBody();
IRMapping firstToSecondPloopIndices;
- firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
- secondPloop.getBody()->getArguments());
+ firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
mayAlias))
return;
- b.setInsertionPointToStart(secondPloop.getBody());
- for (auto &op : firstPloop.getBody()->without_terminator())
- b.clone(op, firstToSecondPloopIndices);
+ DominanceInfo dom;
+ for (Operation *user : firstPloop->getUsers())
+ if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+ return;
+
+ ValueRange inits1 = firstPloop.getInitVals();
+ ValueRange inits2 = secondPloop.getInitVals();
+
+ SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+ newInitVars.append(inits2.begin(), inits2.end());
+
+ IRRewriter b(builder);
+ b.setInsertionPoint(secondPloop);
+ auto newSecondPloop = b.create<ParallelOp>(
+ secondPloop.getLoc(), secondPloop.getLowerBound(),
+ secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+ Block *newBlock = newSecondPloop.getBody();
+ newBlock->getTerminator()->erase();
+
+ block1->getTerminator()->erase();
+
+ b.inlineBlockBefore(block1, newBlock, newBlock->end(),
+ newBlock->getArguments());
+ b.inlineBlockBefore(block2, newBlock, newBlock->end(),
+ newBlock->getArguments());
+
+ ValueRange results = newSecondPloop.getResults();
+ firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+ secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
firstPloop.erase();
+ secondPloop.erase();
+ secondPloop = newSecondPloop;
}
void mlir::scf::naivelyFuseParallelOps(
Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) {
OpBuilder b(region);
// Consider every single block and attempt to fuse adjacent loops.
+ SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
for (auto &block : region) {
- SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+ ploopChains.clear();
+ ploopChains.push_back({});
+
// Not using `walk()` to traverse only top-level parallel loops and also
// make sure that there are no side-effecting ops between the parallel
// loops.
@@ -171,7 +205,7 @@ void mlir::scf::naivelyFuseParallelOps(
// TODO: Handle region side effects properly.
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
}
- for (ArrayRef<ParallelOp> ploops : ploopChains) {
+ for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
}
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 8a42b3a1000ed6..9eb09b7828b848 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -387,3 +387,125 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init1 = arith.constant 1.0 : f32
+ %init2 = arith.constant 2.0 : f32
+ %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ scf.reduce(%A_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ scf.reduce(%B_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.mulf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ return %res1, %res2 : f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions
+// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
+// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+// CHECK: scf.reduce(%[[VAL_A]]) : f32 {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+// CHECK: scf.reduce(%[[VAL_B]]) : f32 {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: scf.yield
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+
+// -----
+
+func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init1 = arith.constant 1.0 : f32
+ %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ scf.reduce(%A_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ scf.reduce(%B_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.mulf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ return %res1, %res2 : f32, f32
+}
+
+// %res1 is used as second scf.parallel arg, cannot fuse
+// CHECK-LABEL: func @reductions_use_res
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init1 = arith.constant 1.0 : f32
+ %init2 = arith.constant 2.0 : f32
+ %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ scf.reduce(%A_elem) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum = arith.addf %B_elem, %res1 : f32
+ scf.reduce(%sum) : f32 {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %1 = arith.mulf %lhs, %rhs : f32
+ scf.reduce.return %1 : f32
+ }
+ scf.yield
+ }
+ return %res1, %res2 : f32, f32
+}
+
+// %res1 is used inside second scf.parallel, cannot fuse
+// CHECK-LABEL: func @reductions_use_res_inside
+// CHECK: scf.parallel
+// CHECK: scf.parallel
|
@matthias-springer FYI, this will require some changes after reductions refactoring is merged |
d824fa5
to
d3a0a84
Compare
Updated to new reductions format |
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please expand a bit the commit message and the comments throughout the code. It may be obvious to you or even someone who worked in this piece of code before, but it won't be to people who may have to bisect or revert this commit for some CI trouble.
auto term1 = cast<ReduceOp>(block1->getTerminator()); | ||
auto term2 = cast<ReduceOp>(block2->getTerminator()); | ||
|
||
b.inlineBlockBefore(block2, newBlock, newBlock->begin(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the difference between:
- Inlining the first loop into the second (old behaviour), and
- Creating a third loop and inlining the second, then the first into it (new behaviour)?
These seem the same to me?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to change results count, when merging scf.parallel
s with reductions, the only way is to recreate the op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you not just re-generate the scf.reduction
inside the second loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Each reduction corresponds to the parent scf.parallel
op result value, so if the first loop had any reductions, those results must be part of the fused parent op, changing total results count.
d3a0a84
to
4513dfd
Compare
4513dfd
to
2dfefdd
Compare
ping |
bf26cf3
to
8f7b4a4
Compare
Properly handle fusion of loops with reductions: * Check there are no first loop results users between loops * Create new loop op with merged reduction init values * Update `scf.reduce` op to contain reductions from both loops * Update loops users with new loop results
Properly handle fusion of loops with reductions: * Check there are no first loop results users between loops * Create new loop op with merged reduction init values * Update `scf.reduce` op to contain reductions from both loops * Update loops users with new loop results
Properly handle fusion of loops with reductions: * Check there are no first loop results users between loops * Create new loop op with merged reduction init values * Update `scf.reduce` op to contain reductions from both loops * Update loops users with new loop results
Properly handle fusion of loops with reductions:
scf.reduce
op to contain reductions from both loops