[mlir][scf] Add reductions support to scf.parallel fusion #75955

Merged: 6 commits, Feb 1, 2024
74 changes: 65 additions & 9 deletions mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -161,29 +161,85 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
}

/// Prepends operations of firstPloop's body into secondPloop's body.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
OpBuilder b,
/// Updates secondPloop with new loop.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
OpBuilder builder,
llvm::function_ref<bool(Value, Value)> mayAlias) {
Block *block1 = firstPloop.getBody();
Block *block2 = secondPloop.getBody();
IRMapping firstToSecondPloopIndices;
firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
secondPloop.getBody()->getArguments());
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());

if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
mayAlias))
return;

b.setInsertionPointToStart(secondPloop.getBody());
for (auto &op : firstPloop.getBody()->without_terminator())
b.clone(op, firstToSecondPloopIndices);
DominanceInfo dom;
// We are fusing first loop into second, make sure there are no users of the
// first loop results between loops.
for (Operation *user : firstPloop->getUsers())
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
return;

ValueRange inits1 = firstPloop.getInitVals();
ValueRange inits2 = secondPloop.getInitVals();

SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
newInitVars.append(inits2.begin(), inits2.end());

IRRewriter b(builder);
b.setInsertionPoint(secondPloop);
auto newSecondPloop = b.create<ParallelOp>(
secondPloop.getLoc(), secondPloop.getLowerBound(),
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);

Block *newBlock = newSecondPloop.getBody();
auto term1 = cast<ReduceOp>(block1->getTerminator());
auto term2 = cast<ReduceOp>(block2->getTerminator());

b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
                    newBlock->getArguments());
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
                    newBlock->getArguments());

Member:

What's the difference between:

  • inlining the first loop into the second (old behaviour), and
  • creating a third loop and inlining the second, then the first, into it (new behaviour)?

These seem the same to me.

@Hardcode84 (Contributor, Author), Jan 3, 2024:

We need to change the results count when merging scf.parallels with reductions; the only way to do that is to recreate the op.

Member:

Can you not just re-generate the scf.reduction inside the second loop?

@Hardcode84 (Contributor, Author):

Each reduction corresponds to a result value of the parent scf.parallel op, so if the first loop had any reductions, those results must become part of the fused parent op, changing the total result count.
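To make the author's explanation concrete, here is a minimal standalone illustration (not part of this PR; the function name and values are invented): each scf.reduce region gives the enclosing scf.parallel one result, so fusing two loops that each carry a reduction produces an op with a different number of results, which is why the op must be recreated rather than updated in place.

func.func @one_reduction_one_result(%A: memref<4xf32>) -> f32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %zero = arith.constant 0.0 : f32
  // One reduction region, therefore exactly one result on the scf.parallel.
  %sum = scf.parallel (%i) = (%c0) to (%c4) step (%c1) init(%zero) -> f32 {
    %v = memref.load %A[%i] : memref<4xf32>
    scf.reduce(%v : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %r = arith.addf %lhs, %rhs : f32
      scf.reduce.return %r : f32
    }
  }
  return %sum : f32
}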

ValueRange results = newSecondPloop.getResults();
if (!results.empty()) {
b.setInsertionPointToEnd(newBlock);

ValueRange reduceArgs1 = term1.getOperands();
ValueRange reduceArgs2 = term2.getOperands();
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());

auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);

for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
term1.getReductions(), term2.getReductions()))) {
Block &oldRedBlock = reg.front();
Block &newRedBlock = newReduceOp.getReductions()[i].front();
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
newRedBlock.getArguments());
}

firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
}
term1->erase();
term2->erase();
firstPloop.erase();
secondPloop.erase();
secondPloop = newSecondPloop;
}

void mlir::scf::naivelyFuseParallelOps(
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
OpBuilder b(region);
// Consider every single block and attempt to fuse adjacent loops.
SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
for (auto &block : region) {
SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
ploopChains.clear();
ploopChains.push_back({});

// Not using `walk()` to traverse only top-level parallel loops and also
// make sure that there are no side-effecting ops between the parallel
// loops.
@@ -201,7 +257,7 @@ void mlir::scf::naivelyFuseParallelOps(
// TODO: Handle region side effects properly.
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
}
for (ArrayRef<ParallelOp> ploops : ploopChains) {
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
}
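As a quick illustration of the side-effect restriction mentioned in the comment above (a made-up sketch, not one of the PR's test cases): a memory-writing op between the two loops ends the current fusion chain, so the loops are left unfused.

func.func @store_between_loops_blocks_fusion(%A: memref<2x2xf32>, %v: f32) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  // A side-effecting op between the loops; the pass starts a new fusion chain
  // here, so the scf.parallel ops above and below stay separate.
  memref.store %v, %A[%c0, %c0] : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.reduce
  }
  return
}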
240 changes: 239 additions & 1 deletion mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -24,6 +24,32 @@ func.func @fuse_empty_loops() {

// -----

func.func @fuse_ops_between(%A: f32, %B: f32) -> f32 {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.reduce
}
%res = arith.addf %A, %B : f32
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.reduce
}
return %res : f32
}
// CHECK-LABEL: func @fuse_ops_between
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
// CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: scf.reduce
// CHECK: }
// CHECK-NOT: scf.parallel

// -----

func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
@@ -89,7 +115,7 @@ func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%res_elem = arith.addf %A_elem, %c2fp : f32
memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
@@ -575,3 +601,215 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
// CHECK-NEXT: }
// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
// CHECK-NEXT: return

// -----

func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%init1 = arith.constant 1.0 : f32
%init2 = arith.constant 2.0 : f32
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
scf.reduce(%B_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
return %res1, %res2 : f32, f32
}

// CHECK-LABEL: func @fuse_reductions_two
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32

// -----

func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%init1 = arith.constant 1.0 : f32
%init2 = arith.constant 2.0 : f32
%init3 = arith.constant 3.0 : f32
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
scf.reduce(%B_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
%res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 {
%A_elem = memref.load %C[%i, %j] : memref<2x2xf32>
scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
return %res1, %res2, %res3 : f32, f32, f32
}

// CHECK-LABEL: func @fuse_reductions_three
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
// CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
// CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32

// -----

func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%init1 = arith.constant 1.0 : f32
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
scf.reduce(%B_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
return %res1, %res2 : f32, f32
}

// %res1 is used as second scf.parallel arg, cannot fuse
// CHECK-LABEL: func @reductions_use_res
// CHECK: scf.parallel
// CHECK: scf.parallel

// -----

func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%init1 = arith.constant 1.0 : f32
%init2 = arith.constant 2.0 : f32
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%sum = arith.addf %B_elem, %res1 : f32
scf.reduce(%sum : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
return %res1, %res2 : f32, f32
}

// %res1 is used inside second scf.parallel, cannot fuse
// CHECK-LABEL: func @reductions_use_res_inside
// CHECK: scf.parallel
// CHECK: scf.parallel

// -----

func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32, f32) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%init1 = arith.constant 1.0 : f32
%init2 = arith.constant 2.0 : f32
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
scf.reduce(%A_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
%res3 = arith.addf %res1, %init2 : f32
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
scf.reduce(%B_elem : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
}
}
return %res1, %res2, %res3 : f32, f32, f32
}

// instruction in between the loops uses the first loop result
// CHECK-LABEL: func @reductions_use_res_between
// CHECK: scf.parallel
// CHECK: scf.parallel