Skip to content

[mlir][scf] Implement conversion from scf.forall to scf.parallel #94109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}

// Transform op that rewrites a single `scf.forall` payload operation into an
// `scf.parallel` operation and returns a handle to the new loop.
def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     DeclareOpInterfaceMethods<TransformOpInterface>]> {
  let summary = "Converts scf.forall into an scf.parallel operation";
  let description = [{
    Converts the `scf.forall` operation pointed to by the given handle into an
    `scf.parallel` operation.

    The operand handle must be associated with exactly one payload operation.

    Loops with outputs are not supported.

    #### Return Modes

    Consumes the operand handle. Produces a silenceable failure if the operand
    is not associated with a single `scf.forall` payload operation.
    Returns a handle to the new `scf.parallel` operation.
    Produces a silenceable failure if another number of resulting handles is
    requested.
  }];
  let arguments = (ins TransformHandleTypeInterface:$target);
  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);

  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}

def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
/// Creates a pass that converts SCF forall loops to SCF for loops.
std::unique_ptr<Pass> createForallToForLoopPass();

/// Creates a pass that converts SCF forall loops to SCF parallel loops.
/// The pass signals failure when a forall loop cannot be converted; the
/// underlying conversion rejects forall loops that have outputs.
std::unique_ptr<Pass> createForallToParallelLoopPass();

// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();

Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
let constructor = "mlir::createForallToForLoopPass()";
}

// Pass registered as `scf-forall-to-parallel` that rewrites `scf.forall` loops
// into `scf.parallel` loops. The pass instance is built by
// `mlir::createForallToParallelLoopPass()`.
def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
let summary = "Convert SCF forall loops to SCF parallel loops";
let constructor = "mlir::createForallToParallelLoopPass()";
}

def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class WhileOp;
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
SmallVectorImpl<Operation *> *results = nullptr);

/// Try converting scf.forall into an scf.parallel loop.
/// The conversion is only supported for forall operations with no results;
/// a match failure is reported on `rewriter` and failure is returned when
/// `forallOp` has outputs.
/// On success the forall op is replaced, and if `result` is non-null it is
/// set to the newly created scf.parallel op.
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
ParallelOp *result = nullptr);

/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRSCFToControlFlow
MLIRArithDialect
MLIRControlFlowDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRTransforms
)
29 changes: 2 additions & 27 deletions mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
Expand Down Expand Up @@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,

/// Lowers an `scf.forall` op to `scf.parallel`.
///
/// The rewrite is delegated to the shared `scf::forallToParallelLoop` helper
/// (declared in mlir/Dialect/SCF/Transforms/Transforms.h) instead of keeping a
/// second, inlined copy of the same lowering here. The helper reports a match
/// failure for forall ops that still have outputs, so non-bufferized loops are
/// rejected exactly as before; it additionally forwards the `mapping`
/// attribute, when present, to the new `scf.parallel` op.
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
                                              PatternRewriter &rewriter) const {
  return scf::forallToParallelLoop(rewriter, forallOp);
}

void mlir::populateSCFToControlFlowConversionPatterns(
Expand Down
44 changes: 44 additions & 0 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ForallToParallelOp
//===----------------------------------------------------------------------===//

/// Converts the single `scf.forall` payload op associated with the target
/// handle into an `scf.parallel` op and maps the only result handle to it.
///
/// A silenceable error is emitted when:
///   - the target handle is not associated with exactly one payload op,
///   - that payload op is not an `scf.forall`,
///   - the `scf.forall` has outputs,
///   - the transform op was not given exactly one result, or
///   - the underlying conversion fails.
DiagnosedSilenceableFailure
transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitSilenceableError() << "expected a single payload op";

  // Exactly one payload op is present; it must be an scf.forall.
  Operation *payloadOp = *payloadOps.begin();
  auto forallOp = dyn_cast<scf::ForallOp>(payloadOp);
  if (!forallOp) {
    DiagnosedSilenceableFailure notForall =
        emitSilenceableError() << "expected the payload to be scf.forall";
    notForall.attachNote(payloadOp->getLoc()) << "payload op";
    return notForall;
  }

  // Loops that still carry outputs are rejected up front.
  if (!forallOp.getOutputs().empty())
    return emitSilenceableError()
           << "unsupported shared outputs (didn't bufferize?)";

  // The result list is variadic in the op definition, but exactly one handle
  // (to the new scf.parallel) is produced.
  if (getNumResults() != 1) {
    DiagnosedSilenceableFailure wrongArity =
        emitSilenceableError()
        << "op expects one result, given " << getNumResults();
    wrongArity.attachNote(forallOp.getLoc()) << "payload op";
    return wrongArity;
  }

  scf::ParallelOp opResult;
  if (failed(scf::forallToParallelLoop(rewriter, forallOp, &opResult)))
    return emitSilenceableError() << "failed to convert forall into parallel";

  results.set(cast<OpResult>(getTransformed()[0]), {opResult});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
Expand Down
86 changes: 86 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ForallOp's into SCF.ParallelOps's.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Replaces `forallOp` with an `scf.parallel` op built from the same bounds,
/// steps and body.
///
/// Fails with a match-failure notification when `forallOp` has outputs. On
/// success the original op is replaced and, if `result` is non-null, `*result`
/// is set to the new `scf.parallel` op. The rewriter's insertion point is
/// restored before returning.
LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
                                              scf::ForallOp forallOp,
                                              scf::ParallelOp *result) {
  // Guard clause: loops with outputs cannot be converted.
  if (!forallOp.getOutputs().empty())
    return rewriter.notifyMatchFailure(
        forallOp,
        "only fully bufferized scf.forall ops can be lowered to scf.parallel");

  // All new IR is created right before the loop being replaced.
  OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
  rewriter.setInsertionPoint(forallOp);
  Location loc = forallOp.getLoc();

  // Materialize the mixed static/dynamic bounds and steps as SSA values.
  SmallVector<Value> lowerBounds = getValueOrCreateConstantIndexOp(
      rewriter, loc, forallOp.getMixedLowerBound());
  SmallVector<Value> upperBounds = getValueOrCreateConstantIndexOp(
      rewriter, loc, forallOp.getMixedUpperBound());
  SmallVector<Value> stepValues =
      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());

  // Build the scf.parallel shell, drop its auto-generated body block and move
  // the forall body into its place.
  auto parallelOp =
      rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, stepValues);
  Region &parallelRegion = parallelOp.getRegion();
  rewriter.eraseBlock(&parallelRegion.front());
  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelRegion,
                              parallelRegion.begin());

  // The moved body still ends with the forall terminator; swap it for the
  // scf.reduce terminator that scf.parallel uses.
  Block &movedBody = parallelRegion.front();
  rewriter.setInsertionPointToEnd(&movedBody);
  rewriter.replaceOpWithNewOp<scf::ReduceOp>(movedBody.getTerminator());

  // Carry the mapping attribute over when the forall op has one.
  if (auto mapping = forallOp.getMapping())
    parallelOp->setAttr("mapping", *mapping);

  // Replace (and thereby erase) the scf.forall op.
  rewriter.replaceOp(forallOp, parallelOp);

  if (result != nullptr)
    *result = parallelOp;
  return success();
}

namespace {
/// Pass that walks the operation it runs on and converts every `scf.forall`
/// it finds into an `scf.parallel` via `scf::forallToParallelLoop`. The pass
/// is marked as failed for each loop whose conversion fails.
struct ForallToParallelLoop final
    : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
  void runOnOperation() override {
    Operation *root = getOperation();
    IRRewriter rewriter(root->getContext());

    // Convert one loop; a failed conversion flags the pass as failed.
    auto convertLoop = [&](scf::ForallOp loop) {
      if (succeeded(scf::forallToParallelLoop(rewriter, loop)))
        return;
      signalPassFailure();
    };
    root->walk(convertLoop);
  }
};
} // namespace

/// Factory for the scf.forall -> scf.parallel conversion pass; returns a newly
/// allocated `ForallToParallelLoop` instance owned by the caller.
std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
return std::make_unique<ForallToParallelLoop>();
}
80 changes: 80 additions & 0 deletions mlir/test/Dialect/SCF/forall-to-parallel.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s

func.func private @callee(%i: index, %j: index)

// A two-dimensional scf.forall with dynamic upper bounds becomes a single
// scf.parallel over the same bounds, terminated by scf.reduce.
// CHECK-LABEL: @two_iters
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @two_iters(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
// CHECK: scf.reduce
return
}

// -----

func.func private @callee(%i: index, %j: index)

// Two sibling scf.forall loops in the same function are each converted into
// their own scf.parallel loop.
// CHECK-LABEL: @repeated
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @repeated(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
// CHECK: scf.reduce
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV3]], %[[IV4]])
// CHECK: scf.reduce
return
}

// -----

func.func private @callee(%i: index, %j: index, %k: index, %l: index)

// An scf.forall nested inside another scf.forall: both levels are converted,
// yielding an scf.parallel nested inside an scf.parallel.
// CHECK-LABEL: @nested
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
// CHECK: scf.reduce
// CHECK: }
// CHECK: scf.reduce
// CHECK: }
scf.forall (%i, %j) in (%ub1, %ub2) {
scf.forall (%k, %l) in (%ub3, %ub4) {
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
}
}
return
}

// -----

// The `mapping` attribute attached to the scf.forall is expected to appear on
// the resulting scf.parallel.
// CHECK-LABEL: @mapping_attr
func.func @mapping_attr() -> () {
// CHECK: scf.parallel
// CHECK: scf.reduce
// CHECK: {mapping = [#gpu.thread<x>]}

%num_threads = arith.constant 100 : index

scf.forall (%thread_idx) in (%num_threads) {
scf.forall.in_parallel {
}
} {mapping = [#gpu.thread<x>]}
return

}
Loading
Loading