Skip to content

Commit c2e23b1

Browse files
committed
Update to new reductions format
1 parent e43ede1 commit c2e23b1

File tree

2 files changed

+102
-25
lines changed

2 files changed

+102
-25
lines changed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,38 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
162162
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
163163

164164
Block *newBlock = newSecondPloop.getBody();
165-
newBlock->getTerminator()->erase();
165+
auto term1 = cast<ReduceOp>(block1->getTerminator());
166+
auto term2 = cast<ReduceOp>(block2->getTerminator());
166167

167-
block1->getTerminator()->erase();
168-
169-
b.inlineBlockBefore(block1, newBlock, newBlock->end(),
168+
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
170169
newBlock->getArguments());
171-
b.inlineBlockBefore(block2, newBlock, newBlock->end(),
170+
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
172171
newBlock->getArguments());
173172

174173
ValueRange results = newSecondPloop.getResults();
175-
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
176-
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
174+
if (!results.empty()) {
175+
b.setInsertionPointToEnd(newBlock);
176+
177+
ValueRange reduceArgs1 = term1.getOperands();
178+
ValueRange reduceArgs2 = term2.getOperands();
179+
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
180+
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
181+
182+
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
183+
184+
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
185+
term1.getReductions(), term2.getReductions()))) {
186+
Block &oldRedBlock = reg.front();
187+
Block &newRedBlock = newReduceOp.getReductions()[i].front();
188+
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
189+
newRedBlock.getArguments());
190+
}
191+
192+
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
193+
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
194+
}
195+
term1->erase();
196+
term2->erase();
177197
firstPloop.erase();
178198
secondPloop.erase();
179199
secondPloop = newSecondPloop;

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -390,34 +390,32 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
390390

391391
// -----
392392

393-
func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
393+
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
394394
%c2 = arith.constant 2 : index
395395
%c0 = arith.constant 0 : index
396396
%c1 = arith.constant 1 : index
397397
%init1 = arith.constant 1.0 : f32
398398
%init2 = arith.constant 2.0 : f32
399399
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
400400
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
401-
scf.reduce(%A_elem) : f32 {
401+
scf.reduce(%A_elem : f32) {
402402
^bb0(%lhs: f32, %rhs: f32):
403403
%1 = arith.addf %lhs, %rhs : f32
404404
scf.reduce.return %1 : f32
405405
}
406-
scf.yield
407406
}
408407
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
409408
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
410-
scf.reduce(%B_elem) : f32 {
409+
scf.reduce(%B_elem : f32) {
411410
^bb0(%lhs: f32, %rhs: f32):
412411
%1 = arith.mulf %lhs, %rhs : f32
413412
scf.reduce.return %1 : f32
414413
}
415-
scf.yield
416414
}
417415
return %res1, %res2 : f32, f32
418416
}
419417

420-
// CHECK-LABEL: func @fuse_reductions
418+
// CHECK-LABEL: func @fuse_reductions_two
421419
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
422420
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
423421
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -428,44 +426,105 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
428426
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
429427
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
430428
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
431-
// CHECK: scf.reduce(%[[VAL_A]]) : f32 {
429+
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
430+
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
432431
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
433432
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
434433
// CHECK: scf.reduce.return %[[R]] : f32
435434
// CHECK: }
436-
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
437-
// CHECK: scf.reduce(%[[VAL_B]]) : f32 {
438435
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
439436
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
440437
// CHECK: scf.reduce.return %[[R]] : f32
441438
// CHECK: }
442-
// CHECK: scf.yield
443439
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
444440

445441
// -----
446442

443+
func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
444+
%c2 = arith.constant 2 : index
445+
%c0 = arith.constant 0 : index
446+
%c1 = arith.constant 1 : index
447+
%init1 = arith.constant 1.0 : f32
448+
%init2 = arith.constant 2.0 : f32
449+
%init3 = arith.constant 3.0 : f32
450+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
451+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
452+
scf.reduce(%A_elem : f32) {
453+
^bb0(%lhs: f32, %rhs: f32):
454+
%1 = arith.addf %lhs, %rhs : f32
455+
scf.reduce.return %1 : f32
456+
}
457+
}
458+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
459+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
460+
scf.reduce(%B_elem : f32) {
461+
^bb0(%lhs: f32, %rhs: f32):
462+
%1 = arith.mulf %lhs, %rhs : f32
463+
scf.reduce.return %1 : f32
464+
}
465+
}
466+
%res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 {
467+
%A_elem = memref.load %C[%i, %j] : memref<2x2xf32>
468+
scf.reduce(%A_elem : f32) {
469+
^bb0(%lhs: f32, %rhs: f32):
470+
%1 = arith.addf %lhs, %rhs : f32
471+
scf.reduce.return %1 : f32
472+
}
473+
}
474+
return %res1, %res2, %res3 : f32, f32, f32
475+
}
476+
477+
// CHECK-LABEL: func @fuse_reductions_three
478+
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
479+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
480+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
481+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
482+
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
483+
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
484+
// CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
485+
// CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
486+
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
487+
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
488+
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
489+
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
490+
// CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
491+
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
492+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
493+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
494+
// CHECK: scf.reduce.return %[[R]] : f32
495+
// CHECK: }
496+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
497+
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
498+
// CHECK: scf.reduce.return %[[R]] : f32
499+
// CHECK: }
500+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
501+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
502+
// CHECK: scf.reduce.return %[[R]] : f32
503+
// CHECK: }
504+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
505+
506+
// -----
507+
447508
func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
448509
%c2 = arith.constant 2 : index
449510
%c0 = arith.constant 0 : index
450511
%c1 = arith.constant 1 : index
451512
%init1 = arith.constant 1.0 : f32
452513
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
453514
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
454-
scf.reduce(%A_elem) : f32 {
515+
scf.reduce(%A_elem : f32) {
455516
^bb0(%lhs: f32, %rhs: f32):
456517
%1 = arith.addf %lhs, %rhs : f32
457518
scf.reduce.return %1 : f32
458519
}
459-
scf.yield
460520
}
461521
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
462522
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
463-
scf.reduce(%B_elem) : f32 {
523+
scf.reduce(%B_elem : f32) {
464524
^bb0(%lhs: f32, %rhs: f32):
465525
%1 = arith.mulf %lhs, %rhs : f32
466526
scf.reduce.return %1 : f32
467527
}
468-
scf.yield
469528
}
470529
return %res1, %res2 : f32, f32
471530
}
@@ -485,22 +544,20 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
485544
%init2 = arith.constant 2.0 : f32
486545
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
487546
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
488-
scf.reduce(%A_elem) : f32 {
547+
scf.reduce(%A_elem : f32) {
489548
^bb0(%lhs: f32, %rhs: f32):
490549
%1 = arith.addf %lhs, %rhs : f32
491550
scf.reduce.return %1 : f32
492551
}
493-
scf.yield
494552
}
495553
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
496554
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
497555
%sum = arith.addf %B_elem, %res1 : f32
498-
scf.reduce(%sum) : f32 {
556+
scf.reduce(%sum : f32) {
499557
^bb0(%lhs: f32, %rhs: f32):
500558
%1 = arith.mulf %lhs, %rhs : f32
501559
scf.reduce.return %1 : f32
502560
}
503-
scf.yield
504561
}
505562
return %res1, %res2 : f32, f32
506563
}

0 commit comments

Comments
 (0)