Skip to content

[mlir][scf] Add reductions support to scf.parallel fusion #75955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 1, 2024

Conversation

Hardcode84
Copy link
Contributor

@Hardcode84 Hardcode84 commented Dec 19, 2023

Properly handle fusion of loops with reductions:

  • Check there are no first loop results users between loops
  • Create new loop op with merged reduction init values
  • Update scf.reduce op to contain reductions from both loops
  • Update loops users with new loop results

@llvmbot
Copy link
Member

llvmbot commented Dec 19, 2023

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/75955.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+43-9)
  • (modified) mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (+122)
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7..ea9fbee26fdeb0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -131,29 +131,63 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
 }
 
 /// Prepends operations of firstPloop's body into secondPloop's body.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
-                        OpBuilder b,
+/// Updates secondPloop with new loop.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
+                        OpBuilder builder,
                         llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *block1 = firstPloop.getBody();
+  Block *block2 = secondPloop.getBody();
   IRMapping firstToSecondPloopIndices;
-  firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
-                                secondPloop.getBody()->getArguments());
+  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
 
   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
                      mayAlias))
     return;
 
-  b.setInsertionPointToStart(secondPloop.getBody());
-  for (auto &op : firstPloop.getBody()->without_terminator())
-    b.clone(op, firstToSecondPloopIndices);
+  DominanceInfo dom;
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return;
+
+  ValueRange inits1 = firstPloop.getInitVals();
+  ValueRange inits2 = secondPloop.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  IRRewriter b(builder);
+  b.setInsertionPoint(secondPloop);
+  auto newSecondPloop = b.create<ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+  Block *newBlock = newSecondPloop.getBody();
+  newBlock->getTerminator()->erase();
+
+  block1->getTerminator()->erase();
+
+  b.inlineBlockBefore(block1, newBlock, newBlock->end(),
+                      newBlock->getArguments());
+  b.inlineBlockBefore(block2, newBlock, newBlock->end(),
+                      newBlock->getArguments());
+
+  ValueRange results = newSecondPloop.getResults();
+  firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+  secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
   firstPloop.erase();
+  secondPloop.erase();
+  secondPloop = newSecondPloop;
 }
 
 void mlir::scf::naivelyFuseParallelOps(
     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
   OpBuilder b(region);
   // Consider every single block and attempt to fuse adjacent loops.
+  SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
   for (auto &block : region) {
-    SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+    ploopChains.clear();
+    ploopChains.push_back({});
+
     // Not using `walk()` to traverse only top-level parallel loops and also
     // make sure that there are no side-effecting ops between the parallel
     // loops.
@@ -171,7 +205,7 @@ void mlir::scf::naivelyFuseParallelOps(
       // TODO: Handle region side effects properly.
       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
     }
-    for (ArrayRef<ParallelOp> ploops : ploopChains) {
+    for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
         fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
     }
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 8a42b3a1000ed6..9eb09b7828b848 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -387,3 +387,125 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
 // CHECK-LABEL: func @do_not_fuse_alias
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel
+
+// -----
+
+func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions
+//  CHECK-SAME:  (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:   %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+//  CHECK-SAME:   to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+//  CHECK-SAME:   init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
+//       CHECK:   %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+//       CHECK:   scf.reduce(%[[VAL_A]]) : f32 {
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+//       CHECK:   scf.reduce(%[[VAL_B]]) : f32 {
+//       CHECK:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:     %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+//       CHECK:     scf.reduce.return %[[R]] : f32
+//       CHECK:   }
+//       CHECK:   scf.yield
+//       CHECK:   return %[[RES]]#0, %[[RES]]#1
+
+// -----
+
+func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// %res1 is used as second scf.parallel arg, cannot fuse
+// CHECK-LABEL: func @reductions_use_res
+// CHECK:      scf.parallel
+// CHECK:      scf.parallel
+
+// -----
+
+func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum = arith.addf %B_elem, %res1 : f32
+    scf.reduce(%sum) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = arith.mulf %lhs, %rhs : f32
+      scf.reduce.return %1 : f32
+    }
+    scf.yield
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// %res1 is used inside second scf.parallel, cannot fuse
+// CHECK-LABEL: func @reductions_use_res_inside
+// CHECK:      scf.parallel
+// CHECK:      scf.parallel

@Hardcode84
Copy link
Contributor Author

@matthias-springer FYI, this will require some changes after reductions refactoring is merged

@Hardcode84
Copy link
Contributor Author

Updated to new reductions format

@Hardcode84
Copy link
Contributor Author

ping

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please expand a bit the commit message and the comments throughout the code. It may be obvious to you or even someone who worked in this piece of code before, but it won't be to people who may have to bisect or revert this commit for some CI trouble.

auto term1 = cast<ReduceOp>(block1->getTerminator());
auto term2 = cast<ReduceOp>(block2->getTerminator());

b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between:

  • Inlining the first loop into the second (old behaviour), and
  • Creating a third loop and inlining the second, then the first into it (new behaviour)?

These seem the same to me?

Copy link
Contributor Author

@Hardcode84 Hardcode84 Jan 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to change results count, when merging scf.parallels with reductions, the only way is to recreate the op.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you not just re-generate the scf.reduction inside the second loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each reduction corresponds to the parent scf.parallel op result value, so if the first loop had any reductions, those results must be part of the fused parent op, changing total results count.

@Hardcode84
Copy link
Contributor Author

ping

@Hardcode84 Hardcode84 merged commit 6050cf2 into llvm:main Feb 1, 2024
@Hardcode84 Hardcode84 deleted the scf-fusion-reduce branch February 1, 2024 15:37
smithp35 pushed a commit to smithp35/llvm-project that referenced this pull request Feb 1, 2024
Properly handle fusion of loops with reductions:
* Check there are no first loop results users between loops
* Create new loop op with merged reduction init values
* Update `scf.reduce` op to contain reductions from both loops
* Update loops users with new loop results
carlosgalvezp pushed a commit to carlosgalvezp/llvm-project that referenced this pull request Feb 1, 2024
Properly handle fusion of loops with reductions:
* Check there are no first loop results users between loops
* Create new loop op with merged reduction init values
* Update `scf.reduce` op to contain reductions from both loops
* Update loops users with new loop results
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
Properly handle fusion of loops with reductions:
* Check there are no first loop results users between loops
* Create new loop op with merged reduction init values
* Update `scf.reduce` op to contain reductions from both loops
* Update loops users with new loop results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants