Skip to content

Commit 6050cf2

Browse files
authored
[mlir][scf] Add reductions support to scf.parallel fusion (#75955)
Properly handle fusion of loops with reductions: * Check there are no first loop results users between loops * Create new loop op with merged reduction init values * Update `scf.reduce` op to contain reductions from both loops * Update loops users with new loop results
1 parent fcd3752 commit 6050cf2

File tree

2 files changed

+304
-10
lines changed

2 files changed

+304
-10
lines changed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,29 +161,85 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
161161
}
162162

163163
/// Prepends operations of firstPloop's body into secondPloop's body.
164-
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
165-
OpBuilder b,
164+
/// Updates secondPloop with new loop.
165+
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
166+
OpBuilder builder,
166167
llvm::function_ref<bool(Value, Value)> mayAlias) {
168+
Block *block1 = firstPloop.getBody();
169+
Block *block2 = secondPloop.getBody();
167170
IRMapping firstToSecondPloopIndices;
168-
firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
169-
secondPloop.getBody()->getArguments());
171+
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
170172

171173
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
172174
mayAlias))
173175
return;
174176

175-
b.setInsertionPointToStart(secondPloop.getBody());
176-
for (auto &op : firstPloop.getBody()->without_terminator())
177-
b.clone(op, firstToSecondPloopIndices);
177+
DominanceInfo dom;
178+
// We are fusing first loop into second, make sure there are no users of the
179+
// first loop results between loops.
180+
for (Operation *user : firstPloop->getUsers())
181+
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
182+
return;
183+
184+
ValueRange inits1 = firstPloop.getInitVals();
185+
ValueRange inits2 = secondPloop.getInitVals();
186+
187+
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
188+
newInitVars.append(inits2.begin(), inits2.end());
189+
190+
IRRewriter b(builder);
191+
b.setInsertionPoint(secondPloop);
192+
auto newSecondPloop = b.create<ParallelOp>(
193+
secondPloop.getLoc(), secondPloop.getLowerBound(),
194+
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
195+
196+
Block *newBlock = newSecondPloop.getBody();
197+
auto term1 = cast<ReduceOp>(block1->getTerminator());
198+
auto term2 = cast<ReduceOp>(block2->getTerminator());
199+
200+
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
201+
newBlock->getArguments());
202+
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
203+
newBlock->getArguments());
204+
205+
ValueRange results = newSecondPloop.getResults();
206+
if (!results.empty()) {
207+
b.setInsertionPointToEnd(newBlock);
208+
209+
ValueRange reduceArgs1 = term1.getOperands();
210+
ValueRange reduceArgs2 = term2.getOperands();
211+
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
212+
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
213+
214+
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
215+
216+
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
217+
term1.getReductions(), term2.getReductions()))) {
218+
Block &oldRedBlock = reg.front();
219+
Block &newRedBlock = newReduceOp.getReductions()[i].front();
220+
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
221+
newRedBlock.getArguments());
222+
}
223+
224+
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
225+
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
226+
}
227+
term1->erase();
228+
term2->erase();
178229
firstPloop.erase();
230+
secondPloop.erase();
231+
secondPloop = newSecondPloop;
179232
}
180233

181234
void mlir::scf::naivelyFuseParallelOps(
182235
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
183236
OpBuilder b(region);
184237
// Consider every single block and attempt to fuse adjacent loops.
238+
SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
185239
for (auto &block : region) {
186-
SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
240+
ploopChains.clear();
241+
ploopChains.push_back({});
242+
187243
// Not using `walk()` to traverse only top-level parallel loops and also
188244
// make sure that there are no side-effecting ops between the parallel
189245
// loops.
@@ -201,7 +257,7 @@ void mlir::scf::naivelyFuseParallelOps(
201257
// TODO: Handle region side effects properly.
202258
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
203259
}
204-
for (ArrayRef<ParallelOp> ploops : ploopChains) {
260+
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
205261
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
206262
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
207263
}

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 239 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,32 @@ func.func @fuse_empty_loops() {
2424

2525
// -----
2626

27+
func.func @fuse_ops_between(%A: f32, %B: f32) -> f32 {
28+
%c2 = arith.constant 2 : index
29+
%c0 = arith.constant 0 : index
30+
%c1 = arith.constant 1 : index
31+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
32+
scf.reduce
33+
}
34+
%res = arith.addf %A, %B : f32
35+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
36+
scf.reduce
37+
}
38+
return %res : f32
39+
}
40+
// CHECK-LABEL: func @fuse_ops_between
41+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
42+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
43+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
44+
// CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
45+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
46+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
47+
// CHECK: scf.reduce
48+
// CHECK: }
49+
// CHECK-NOT: scf.parallel
50+
51+
// -----
52+
2753
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
2854
%c2 = arith.constant 2 : index
2955
%c0 = arith.constant 0 : index
@@ -89,7 +115,7 @@ func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
89115
memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
90116
scf.reduce
91117
}
92-
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
118+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
93119
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
94120
%res_elem = arith.addf %A_elem, %c2fp : f32
95121
memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
@@ -575,3 +601,215 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
575601
// CHECK-NEXT: }
576602
// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
577603
// CHECK-NEXT: return
604+
605+
// -----
606+
607+
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
608+
%c2 = arith.constant 2 : index
609+
%c0 = arith.constant 0 : index
610+
%c1 = arith.constant 1 : index
611+
%init1 = arith.constant 1.0 : f32
612+
%init2 = arith.constant 2.0 : f32
613+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
614+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
615+
scf.reduce(%A_elem : f32) {
616+
^bb0(%lhs: f32, %rhs: f32):
617+
%1 = arith.addf %lhs, %rhs : f32
618+
scf.reduce.return %1 : f32
619+
}
620+
}
621+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
622+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
623+
scf.reduce(%B_elem : f32) {
624+
^bb0(%lhs: f32, %rhs: f32):
625+
%1 = arith.mulf %lhs, %rhs : f32
626+
scf.reduce.return %1 : f32
627+
}
628+
}
629+
return %res1, %res2 : f32, f32
630+
}
631+
632+
// CHECK-LABEL: func @fuse_reductions_two
633+
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
634+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
635+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
636+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
637+
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
638+
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
639+
// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
640+
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
641+
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
642+
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
643+
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
644+
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
645+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
646+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
647+
// CHECK: scf.reduce.return %[[R]] : f32
648+
// CHECK: }
649+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
650+
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
651+
// CHECK: scf.reduce.return %[[R]] : f32
652+
// CHECK: }
653+
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
654+
655+
// -----
656+
657+
func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
658+
%c2 = arith.constant 2 : index
659+
%c0 = arith.constant 0 : index
660+
%c1 = arith.constant 1 : index
661+
%init1 = arith.constant 1.0 : f32
662+
%init2 = arith.constant 2.0 : f32
663+
%init3 = arith.constant 3.0 : f32
664+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
665+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
666+
scf.reduce(%A_elem : f32) {
667+
^bb0(%lhs: f32, %rhs: f32):
668+
%1 = arith.addf %lhs, %rhs : f32
669+
scf.reduce.return %1 : f32
670+
}
671+
}
672+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
673+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
674+
scf.reduce(%B_elem : f32) {
675+
^bb0(%lhs: f32, %rhs: f32):
676+
%1 = arith.mulf %lhs, %rhs : f32
677+
scf.reduce.return %1 : f32
678+
}
679+
}
680+
%res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 {
681+
%A_elem = memref.load %C[%i, %j] : memref<2x2xf32>
682+
scf.reduce(%A_elem : f32) {
683+
^bb0(%lhs: f32, %rhs: f32):
684+
%1 = arith.addf %lhs, %rhs : f32
685+
scf.reduce.return %1 : f32
686+
}
687+
}
688+
return %res1, %res2, %res3 : f32, f32, f32
689+
}
690+
691+
// CHECK-LABEL: func @fuse_reductions_three
692+
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
693+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
694+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
695+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
696+
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
697+
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
698+
// CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
699+
// CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
700+
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
701+
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
702+
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
703+
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
704+
// CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
705+
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
706+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
707+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
708+
// CHECK: scf.reduce.return %[[R]] : f32
709+
// CHECK: }
710+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
711+
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
712+
// CHECK: scf.reduce.return %[[R]] : f32
713+
// CHECK: }
714+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
715+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
716+
// CHECK: scf.reduce.return %[[R]] : f32
717+
// CHECK: }
718+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
719+
720+
// -----
721+
722+
func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
723+
%c2 = arith.constant 2 : index
724+
%c0 = arith.constant 0 : index
725+
%c1 = arith.constant 1 : index
726+
%init1 = arith.constant 1.0 : f32
727+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
728+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
729+
scf.reduce(%A_elem : f32) {
730+
^bb0(%lhs: f32, %rhs: f32):
731+
%1 = arith.addf %lhs, %rhs : f32
732+
scf.reduce.return %1 : f32
733+
}
734+
}
735+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
736+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
737+
scf.reduce(%B_elem : f32) {
738+
^bb0(%lhs: f32, %rhs: f32):
739+
%1 = arith.mulf %lhs, %rhs : f32
740+
scf.reduce.return %1 : f32
741+
}
742+
}
743+
return %res1, %res2 : f32, f32
744+
}
745+
746+
// %res1 is used as second scf.parallel arg, cannot fuse
747+
// CHECK-LABEL: func @reductions_use_res
748+
// CHECK: scf.parallel
749+
// CHECK: scf.parallel
750+
751+
// -----
752+
753+
func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
754+
%c2 = arith.constant 2 : index
755+
%c0 = arith.constant 0 : index
756+
%c1 = arith.constant 1 : index
757+
%init1 = arith.constant 1.0 : f32
758+
%init2 = arith.constant 2.0 : f32
759+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
760+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
761+
scf.reduce(%A_elem : f32) {
762+
^bb0(%lhs: f32, %rhs: f32):
763+
%1 = arith.addf %lhs, %rhs : f32
764+
scf.reduce.return %1 : f32
765+
}
766+
}
767+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
768+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
769+
%sum = arith.addf %B_elem, %res1 : f32
770+
scf.reduce(%sum : f32) {
771+
^bb0(%lhs: f32, %rhs: f32):
772+
%1 = arith.mulf %lhs, %rhs : f32
773+
scf.reduce.return %1 : f32
774+
}
775+
}
776+
return %res1, %res2 : f32, f32
777+
}
778+
779+
// %res1 is used inside second scf.parallel, cannot fuse
780+
// CHECK-LABEL: func @reductions_use_res_inside
781+
// CHECK: scf.parallel
782+
// CHECK: scf.parallel
783+
784+
// -----
785+
786+
func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32, f32) {
787+
%c2 = arith.constant 2 : index
788+
%c0 = arith.constant 0 : index
789+
%c1 = arith.constant 1 : index
790+
%init1 = arith.constant 1.0 : f32
791+
%init2 = arith.constant 2.0 : f32
792+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
793+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
794+
scf.reduce(%A_elem : f32) {
795+
^bb0(%lhs: f32, %rhs: f32):
796+
%1 = arith.addf %lhs, %rhs : f32
797+
scf.reduce.return %1 : f32
798+
}
799+
}
800+
%res3 = arith.addf %res1, %init2 : f32
801+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
802+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
803+
scf.reduce(%B_elem : f32) {
804+
^bb0(%lhs: f32, %rhs: f32):
805+
%1 = arith.mulf %lhs, %rhs : f32
806+
scf.reduce.return %1 : f32
807+
}
808+
}
809+
return %res1, %res2, %res3 : f32, f32, f32
810+
}
811+
812+
// instruction in between the loops uses the first loop result
813+
// CHECK-LABEL: func @reductions_use_res_between
814+
// CHECK: scf.parallel
815+
// CHECK: scf.parallel

0 commit comments

Comments
 (0)