Commit 0e944a3

[SCFToGPU] Convert scf.parallel+scf.reduce to gpu.all_reduce (#122782)
Support reductions in SCFToGPU: the `scf.parallel` and `scf.reduce` op combination is now converted to a `gpu.all_reduce` op (currently limited to a single reduction whose operand is a signless integer or float type).
1 parent e069518 commit 0e944a3
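
For orientation, here is a condensed sketch of the conversion described above, distilled from the test cases added in this commit (the value names %buf, %zero, %sum and the loop bounds are illustrative, not taken verbatim from the patch): an scf.parallel whose scf.reduce terminator carries a single operand has its reduction body turned into a gpu.all_reduce region inside the generated gpu.launch, with scf.reduce.return becoming gpu.yield.

// Before (sketch): 1-d parallel reduction over a 64-element buffer.
%sum = scf.parallel (%i) = (%c0) to (%c64) step (%c1) init (%zero) -> f32 {
  %v = memref.load %buf[%i] : memref<64xf32>
  scf.reduce(%v : f32) {
  ^bb0(%lhs: f32, %rhs: f32):
    %r = arith.addf %lhs, %rhs : f32
    scf.reduce.return %r : f32
  }
} {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}

// After (sketch): inside the generated gpu.launch region.
%sum = gpu.all_reduce %v {
^bb0(%lhs: f32, %rhs: f32):
  %r = arith.addf %lhs, %rhs : f32
  gpu.yield %r : f32
} : (f32) -> f32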

File tree

2 files changed: +247 −2 lines

mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp

+34 −2
@@ -408,8 +408,8 @@ static LogicalResult processParallelLoop(
   ArrayAttr mapping =
       parallelOp->getAttrOfType<ArrayAttr>(gpu::getMappingAttrName());

-  // TODO: Support reductions.
-  if (!mapping || parallelOp.getNumResults() != 0)
+  // TODO: Support multiple reductions.
+  if (!mapping || parallelOp.getNumResults() > 1)
     return failure();

   Location loc = parallelOp.getLoc();
@@ -556,6 +556,11 @@ static LogicalResult processParallelLoop(

   Block *body = parallelOp.getBody();
   worklist.reserve(worklist.size() + body->getOperations().size());
+  // Include the scf.reduce terminator if it exists and has an operand.
+  if (auto terminator = body->getTerminator();
+      isa<scf::ReduceOp>(terminator) && terminator->getOperands().size() == 1) {
+    worklist.push_back(terminator);
+  }
   for (Operation &op : llvm::reverse(body->without_terminator()))
     worklist.push_back(&op);
   return success();
@@ -648,6 +653,33 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
       rewriter.setInsertionPointAfter(parent);
       leftNestingScope = true;
       seenSideeffects = false;
+    } else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
+      // Convert scf.reduce op.
+      auto parentLoop = op->getParentOfType<ParallelOp>();
+      if (!parentLoop || op->getOperands().size() != 1)
+        return failure();
+      auto operand = op->getOperands().front();
+      auto newValue = cloningMap.lookupOrNull(operand);
+      if (!newValue || !operand.getType().isSignlessIntOrFloat())
+        return failure();
+      // Ensure reduction region is isolated from above.
+      llvm::SetVector<Value> externalValues;
+      getUsedValuesDefinedAbove(reduceOp.getRegion(0), externalValues);
+      if (externalValues.size())
+        return failure();
+      // Replace by gpu.all_reduce.
+      auto gpuRedOp = rewriter.create<gpu::AllReduceOp>(loc, newValue);
+      cloningMap.map(parentLoop->getResult(0), gpuRedOp.getResult());
+      // Copy region.
+      rewriter.inlineRegionBefore(reduceOp.getRegion(0), gpuRedOp.getRegion(),
+                                  gpuRedOp.getRegion().begin());
+      // Replace scf.reduce.return with gpu.yield.
+      auto scfReturn = gpuRedOp.getRegion().front().getTerminator();
+      auto ip = rewriter.saveInsertionPoint();
+      rewriter.setInsertionPointToEnd(&gpuRedOp.getRegion().front());
+      rewriter.replaceOpWithNewOp<gpu::YieldOp>(
+          scfReturn, scfReturn->getOperands().front());
+      rewriter.restoreInsertionPoint(ip);
     } else {
       // Otherwise we copy it over.
       Operation *clone = rewriter.clone(*op, cloningMap);
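
The new branch only fires under three guards visible above: the scf.reduce must carry exactly one operand, that operand must have a signless integer or float type, and the reduction region must not use values defined outside it. As a hedged sketch of what therefore still falls back to failure() (hypothetical values; multi-reduction syntax as documented for the SCF dialect), a loop producing two reduction results is rejected by the getNumResults() > 1 check in processParallelLoop:

// Sketch only: two reductions (sum and product) in one scf.parallel.
// This still bails out, matching the "TODO: Support multiple reductions."
%s, %p = scf.parallel (%i) = (%c0) to (%c64) step (%c1)
             init (%zero, %one) -> (f32, f32) {
  %v = memref.load %buf[%i] : memref<64xf32>
  scf.reduce(%v, %v : f32, f32) {
  ^bb0(%a: f32, %b: f32):
    %r = arith.addf %a, %b : f32
    scf.reduce.return %r : f32
  }, {
  ^bb0(%a: f32, %b: f32):
    %r = arith.mulf %a, %b : f32
    scf.reduce.return %r : f32
  }
} {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}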

mlir/test/Conversion/SCFToGPU/parallel_loop.mlir

+213 −0
@@ -428,3 +428,216 @@ func.func @step_invariant() {
 // CHECK: %[[rhs:.*]] = memref.load %[[alloc_1]][%[[dim0]], %[[dim1]]] : memref<1x1xf64>
 // CHECK: %[[sum:.*]] = arith.addf %[[lhs]], %[[rhs]] : f64
 // CHECK: memref.store %[[sum]], %[[alloc_0]][%[[dim0]], %[[dim1]]] : memref<1x1xf64>
+
+// -----
+
+// 1-d parallel reduction mapped to block.x and thread.x.
+
+// CHECK-LABEL: @parallel_reduction_1d
+func.func @parallel_reduction_1d() {
+  %alloc = memref.alloc() : memref<f32>
+  %alloc_0 = memref.alloc() : memref<64xf32>
+  %c1 = arith.constant 1 : index
+  %c64 = arith.constant 64 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  scf.parallel (%arg1) = (%c0) to (%c1) step (%c1) {
+    %0 = scf.parallel (%arg2) = (%c0) to (%c64) step (%c1) init (%cst) -> f32 {
+      %1 = memref.load %alloc_0[%arg2] : memref<64xf32>
+      scf.reduce(%1 : f32) {
+      ^bb0(%arg3: f32, %arg4: f32):
+        %2 = arith.addf %arg3, %arg4 : f32
+        scf.reduce.return %2 : f32
+      }
+    } {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+    memref.store %0, %alloc[] : memref<f32>
+    scf.reduce
+  } {mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+  memref.dealloc %alloc : memref<f32>
+  memref.dealloc %alloc_0 : memref<64xf32>
+  return
+}
+
+// CHECK: %[[alloc_0:.*]] = memref.alloc() : memref<f32>
+// CHECK: %[[alloc_1:.*]] = memref.alloc() : memref<64xf32>
+// CHECK: %[[map_0:.*]] = affine.apply #map({{.*}})[{{.*}}, {{.*}}]
+// CHECK: %[[map_1:.*]] = affine.apply #map({{.*}})[{{.*}}, {{.*}}]
+// CHECK: gpu.launch
+// CHECK-SAME: blocks(%[[arg_0:.*]], %{{[^)]*}}, %{{[^)]*}}) in (%{{[^)]*}} = %[[map_0]], %{{[^)]*}} = %{{[^)]*}}, %{{[^)]*}} = %{{[^)]*}})
+// CHECK-SAME: threads(%[[arg_3:.*]], %{{[^)]*}}, %{{[^)]*}}) in (%{{[^)]*}} = %[[map_1]], %{{[^)]*}} = %{{[^)]*}}, %{{[^)]*}} = %{{[^)]*}})
+// CHECK-NEXT: %[[dim0:.*]] = affine.apply #map1(%[[arg_0]])[{{.*}}, {{.*}}]
+// CHECK-NEXT: %[[dim1:.*]] = affine.apply #map1(%[[arg_3]])[{{.*}}, {{.*}}]
+// CHECK-NEXT: %[[src:.*]] = memref.load %[[alloc_1]][%[[dim1]]] : memref<64xf32>
+// CHECK-NEXT: %[[res:.*]] = gpu.all_reduce %[[src]] {
+// CHECK-NEXT: ^bb0(%[[arg12:.*]]: f32, %[[arg13:.*]]: f32):
+// CHECK-NEXT: %[[sum:.*]] = arith.addf %[[arg12]], %[[arg13]] : f32
+// CHECK-NEXT: gpu.yield %[[sum]] : f32
+// CHECK-NEXT: } : (f32) -> f32
+// CHECK-NEXT: memref.store %[[res]], %[[alloc_0]][] : memref<f32>
+
+// -----
+
+// 2-d parallel reduction mapped to block.x and thread.x and thread.y.
+
+// CHECK-LABEL: @parallel_reduction_2d
+func.func @parallel_reduction_2d() {
+  %alloc = memref.alloc() : memref<f32>
+  %alloc_0 = memref.alloc() : memref<8x8xf32>
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  scf.parallel (%arg1) = (%c0) to (%c1) step (%c1) {
+    %0 = scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) init (%cst) -> f32 {
+      %1 = memref.load %alloc_0[%arg2, %arg3] : memref<8x8xf32>
+      scf.reduce(%1 : f32) {
+      ^bb0(%arg4: f32, %arg5: f32):
+        %2 = arith.addf %arg4, %arg5 : f32
+        scf.reduce.return %2 : f32
+      }
+    } {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>, #gpu.loop_dim_map<processor = thread_y, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+    memref.store %0, %alloc[] : memref<f32>
+    scf.reduce
+  } {mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+  memref.dealloc %alloc : memref<f32>
+  memref.dealloc %alloc_0 : memref<8x8xf32>
+  return
+}
+
+// CHECK: %[[alloc_0:.*]] = memref.alloc() : memref<f32>
+// CHECK: %[[alloc_1:.*]] = memref.alloc() : memref<8x8xf32>
+// CHECK: %[[map_0:.*]] = affine.apply #map({{.*}})[{{.*}}, {{.*}}]
+// CHECK: %[[map_1:.*]] = affine.apply #map({{.*}})[{{.*}}, {{.*}}]
+// CHECK: %[[map_2:.*]] = affine.apply #map({{.*}})[{{.*}}, {{.*}}]
+// CHECK: gpu.launch
+// CHECK-SAME: blocks(%[[arg_0:.*]], %{{[^)]*}}, %{{[^)]*}}) in (%{{[^)]*}} = %[[map_0]], %{{[^)]*}} = %{{[^)]*}}, %{{[^)]*}} = %{{[^)]*}})
+// CHECK-SAME: threads(%[[arg_3:.*]], %[[arg_4:.*]], %{{[^)]*}}) in (%{{[^)]*}} = %[[map_1]], %{{[^)]*}} = %[[map_2]], %{{[^)]*}} = %{{[^)]*}})
+// CHECK-NEXT: %[[dim0:.*]] = affine.apply #map1(%[[arg_0]])[{{.*}}, {{.*}}]
+// CHECK-NEXT: %[[dim1:.*]] = affine.apply #map1(%[[arg_3]])[{{.*}}, {{.*}}]
+// CHECK-NEXT: %[[dim2:.*]] = affine.apply #map1(%[[arg_4]])[{{.*}}, {{.*}}]
+// CHECK-NEXT: %[[src:.*]] = memref.load %[[alloc_1]][%[[dim1]], %[[dim2]]] : memref<8x8xf32>
+// CHECK-NEXT: %[[res:.*]] = gpu.all_reduce %[[src]] {
+// CHECK-NEXT: ^bb0(%[[arg12:.*]]: f32, %[[arg13:.*]]: f32):
+// CHECK-NEXT: %[[sum:.*]] = arith.addf %[[arg12]], %[[arg13]] : f32
+// CHECK-NEXT: gpu.yield %[[sum]] : f32
+// CHECK-NEXT: } : (f32) -> f32
+// CHECK-NEXT: memref.store %[[res]], %[[alloc_0]][] : memref<f32>
+
+// -----
+
+// tiled 1-d parallel reduction mapped to block.x and thread.x.
+
+// CHECK-LABEL: @parallel_reduction_1d_tiled
+func.func @parallel_reduction_1d_tiled() {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c64 = arith.constant 64 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %alloc_0 = memref.alloc() : memref<8192xf32>
+  %alloc_1 = memref.alloc() : memref<64xf32>
+  scf.parallel (%arg1) = (%c0) to (%c64) step (%c1) {
+    %subview = memref.subview %alloc_1[%arg1] [1] [1] : memref<64xf32> to memref<f32, strided<[], offset: ?>>
+    %0 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg1)
+    %subview_1 = memref.subview %alloc_0[%0] [128] [1] : memref<8192xf32> to memref<128xf32, strided<[1], offset: ?>>
+    %1 = scf.parallel (%arg2) = (%c0) to (%c128) step (%c1) init (%cst) -> f32 {
+      %2 = memref.load %subview_1[%arg2] : memref<128xf32, strided<[1], offset: ?>>
+      scf.reduce(%2 : f32) {
+      ^bb0(%arg3: f32, %arg4: f32):
+        %3 = arith.addf %arg3, %arg4 : f32
+        scf.reduce.return %3 : f32
+      }
+    } {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+    memref.store %1, %subview[] : memref<f32, strided<[], offset: ?>>
+    scf.reduce
+  } {mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+  memref.dealloc %alloc_0 : memref<8192xf32>
+  memref.dealloc %alloc_1 : memref<64xf32>
+  return
+}
+
+// CHECK: %[[alloc_0:.*]] = memref.alloc() : memref<8192xf32>
+// CHECK: %[[alloc_1:.*]] = memref.alloc() : memref<64xf32>
+// CHECK: %[[map_0:.*]] = affine.apply #map({{.*}})[{{.*}}, {{.*}}]
+// CHECK: %[[map_1:.*]] = affine.apply #map({{.*}})[{{.*}}, {{.*}}]
+// CHECK: gpu.launch
+// CHECK-SAME: blocks(%[[arg_0:.*]], %{{[^)]*}}, %{{[^)]*}}) in (%{{[^)]*}} = %[[map_0]], %{{[^)]*}} = %{{[^)]*}}, %{{[^)]*}} = %{{[^)]*}})
+// CHECK-SAME: threads(%[[arg_3:.*]], %{{[^)]*}}, %{{[^)]*}}) in (%{{[^)]*}} = %[[map_1]], %{{[^)]*}} = %{{[^)]*}}, %{{[^)]*}} = %{{[^)]*}})
+// CHECK-NEXT: %[[dim0:.*]] = affine.apply #map1(%[[arg_0]])[{{.*}}, {{.*}}]
+// CHECK-NEXT: %[[dst:.*]] = memref.subview %[[alloc_1]][%[[dim0]]] [1] [1] : memref<64xf32>
+// CHECK-NEXT: %[[dim1:.*]] = affine.apply #map2(%[[dim0]])
+// CHECK-NEXT: %[[tile:.*]] = memref.subview %[[alloc_0]][%[[dim1]]] [128] [1] : memref<8192xf32>
+// CHECK-NEXT: %[[dim2:.*]] = affine.apply #map1(%[[arg_3]])[{{.*}}, {{.*}}]
+// CHECK-NEXT: %[[src:.*]] = memref.load %[[tile]][%[[dim2]]] : memref<128xf32, strided<[1], offset: ?>>
+// CHECK-NEXT: %[[res:.*]] = gpu.all_reduce %[[src]] {
+// CHECK-NEXT: ^bb0(%[[arg12:.*]]: f32, %[[arg13:.*]]: f32):
+// CHECK-NEXT: %[[sum:.*]] = arith.addf %[[arg12]], %[[arg13]] : f32
+// CHECK-NEXT: gpu.yield %[[sum]] : f32
+// CHECK-NEXT: } : (f32) -> f32
+// CHECK-NEXT: memref.store %[[res]], %[[dst]][] : memref<f32, strided<[], offset: ?>>
+
+// -----
+
+// 1-d parallel reduction, unsigned int. Cannot be mapped.
+
+// CHECK-LABEL: @parallel_reduction_1d_uint
+func.func @parallel_reduction_1d_uint(%cst : ui32) {
+  %alloc = memref.alloc() : memref<ui32>
+  %alloc_0 = memref.alloc() : memref<64xui32>
+  %c1 = arith.constant 1 : index
+  %c64 = arith.constant 64 : index
+  %c0 = arith.constant 0 : index
+  scf.parallel (%arg1) = (%c0) to (%c1) step (%c1) {
+    %0 = scf.parallel (%arg2) = (%c0) to (%c64) step (%c1) init (%cst) -> ui32 {
+      %1 = memref.load %alloc_0[%arg2] : memref<64xui32>
+      scf.reduce(%1 : ui32) {
+      ^bb0(%arg3: ui32, %arg4: ui32):
+        scf.reduce.return %arg3 : ui32
+      }
+    } {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+    memref.store %0, %alloc[] : memref<ui32>
+    scf.reduce
+  } {mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+  memref.dealloc %alloc : memref<ui32>
+  memref.dealloc %alloc_0 : memref<64xui32>
+  return
+}
+
+// CHECK: scf.parallel
+// CHECK-NEXT: scf.parallel
+// CHECK: scf.reduce
+
+// -----
+
+// 1-d parallel reduction, not isolated from above. Cannot be mapped.
+
+// CHECK-LABEL: @parallel_reduction_1d_outside
+func.func @parallel_reduction_1d_outside() {
+  %alloc = memref.alloc() : memref<f32>
+  %alloc_0 = memref.alloc() : memref<64xf32>
+  %c1 = arith.constant 1 : index
+  %c64 = arith.constant 64 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %const = arith.constant 1.000000e+00 : f32
+  scf.parallel (%arg1) = (%c0) to (%c1) step (%c1) {
+    %0 = scf.parallel (%arg2) = (%c0) to (%c64) step (%c1) init (%cst) -> f32 {
+      %1 = memref.load %alloc_0[%arg2] : memref<64xf32>
+      scf.reduce(%1 : f32) {
+      ^bb0(%arg3: f32, %arg4: f32):
+        %2 = arith.addf %arg3, %arg4 : f32
+        %3 = arith.addf %2, %const : f32
+        scf.reduce.return %3 : f32
+      }
+    } {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+    memref.store %0, %alloc[] : memref<f32>
+    scf.reduce
+  } {mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+  memref.dealloc %alloc : memref<f32>
+  memref.dealloc %alloc_0 : memref<64xf32>
+  return
+}
+
+// CHECK: scf.parallel
+// CHECK-NEXT: scf.parallel
+// CHECK: scf.reduce
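
The new tests follow the file's existing split-input FileCheck pattern. Roughly, they are exercised with an invocation along these lines; the actual RUN line is outside this diff, so treat the exact flags as an assumption to verify against your checkout:

// RUN (sketch): mlir-opt parallel_loop.mlir -convert-parallel-loops-to-gpu -split-input-file | FileCheck parallel_loop.mlir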
