[mlir][scf] Implement conversion from scf.forall to scf.parallel #94109
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir
Author: Spenser Bauman (sabauma)
Changes
There is currently no path to lower scf.forall to scf.parallel with the goal of targeting the OpenMP dialect. In the SCF->ControlFlow conversion, scf.forall is briefly converted to scf.parallel, but the scf.parallel is lowered directly to a sequential loop. This makes experimenting with scf.forall for CPU execution difficult. This change factors out the rewrite in the SCF->ControlFlow pass into a utility function that can then be used in the SCF->ControlFlow lowering and via a separate -scf-forall-to-parallel pass.
Full diff: https://github.com/llvm/llvm-project/pull/94109.diff
10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..3d7fe7b0f093f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
+def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Converts scf.forall into an scf.parallel operation";
+ let description = [{
+ Converts the `scf.forall` operation pointed to by the given handle into an
+ `scf.parallel` operation.
+
+ The operand handle must be associated with exactly one payload operation.
+
+ Loops with outputs are not supported.
+
+ #### Return Modes
+
+ Consumes the operand handle. Produces a silenceable failure if the operand
+ is not associated with a single `scf.forall` payload operation.
+ Returns a handle to the new `scf.parallel` operation.
+ Produces a silenceable failure if another number of resulting handles is
+ requested.
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 31c3d0eb629d2..fb8411418ff9a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
/// Creates a pass that converts SCF forall loops to SCF for loops.
std::unique_ptr<Pass> createForallToForLoopPass();
+/// Creates a pass that converts SCF forall loops to SCF parallel loops.
+std::unique_ptr<Pass> createForallToParallelLoopPass();
+
// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index a7aeb42d60c0e..9b29affb97c43 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
let constructor = "mlir::createForallToForLoopPass()";
}
+def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
+ let summary = "Convert SCF forall loops to SCF parallel loops";
+ let constructor = "mlir::createForallToParallelLoopPass()";
+}
+
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index b063e6e775e63..02ad7ab135f86 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -39,6 +39,12 @@ class WhileOp;
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
SmallVectorImpl<Operation *> *results = nullptr);
+/// Try converting scf.forall into an scf.parallel loop.
+/// The conversion is only supported for forall operations with no results.
+LogicalResult forallToParallelLoop(RewriterBase &rewriter,
+ ForallOp forallOp,
+ ParallelOp *result = nullptr);
+
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 9eb8a289d7d65..16f1db44acc35 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
PatternRewriter &rewriter) const {
- Location loc = forallOp.getLoc();
- if (!forallOp.getOutputs().empty())
- return rewriter.notifyMatchFailure(
- forallOp,
- "only fully bufferized scf.forall ops can be lowered to scf.parallel");
-
- // Convert mixed bounds and steps to SSA values.
- SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
- rewriter, loc, forallOp.getMixedLowerBound());
- SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
- rewriter, loc, forallOp.getMixedUpperBound());
- SmallVector<Value> steps =
- getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
-
- // Create empty scf.parallel op.
- auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
- rewriter.eraseBlock(&parallelOp.getRegion().front());
- rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
- parallelOp.getRegion().begin());
- // Replace the terminator.
- rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
- rewriter.replaceOpWithNewOp<scf::ReduceOp>(
- parallelOp.getRegion().front().getTerminator());
-
- // Erase the scf.forall op.
- rewriter.replaceOp(forallOp, parallelOp);
- return success();
+ return scf::forallToParallelLoop(rewriter, forallOp);
}
void mlir::populateSCFToControlFlowConversionPatterns(
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..7bdf3ac1d6ac3 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ForallToParallelOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto payload = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(payload))
+ return emitSilenceableError() << "expected a single payload op";
+
+ auto target = dyn_cast<scf::ForallOp>(*payload.begin());
+ if (!target) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "expected the payload to be scf.forall";
+ diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+ return diag;
+ }
+
+ if (!target.getOutputs().empty()) {
+ return emitSilenceableError()
+ << "unsupported shared outputs (didn't bufferize?)";
+ }
+
+ if (getNumResults() != 1) {
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "op expects one result, given "
+ << getNumResults();
+ diag.attachNote(target.getLoc()) << "payload op";
+ return diag;
+ }
+
+ scf::ParallelOp opResult;
+ if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "failed to convert forall into parallel";
+ return diag;
+ }
+
+ results.set(cast<OpResult>(getTransformed()[0]), {opResult});
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e7671c9cc28f8..d363ffe941fce 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForallToFor.cpp
+ ForallToParallel.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
new file mode 100644
index 0000000000000..8882d083635f2
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -0,0 +1,83 @@
+//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForallOp's into SCF.ParallelOps's.
+//
+//===----------------------------------------------------------------------===//
+
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
+ scf::ForallOp forallOp,
+ scf::ParallelOp *result) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(forallOp);
+
+ Location loc = forallOp.getLoc();
+ if (!forallOp.getOutputs().empty())
+ return rewriter.notifyMatchFailure(
+ forallOp,
+ "only fully bufferized scf.forall ops can be lowered to scf.parallel");
+
+ // Convert mixed bounds and steps to SSA values.
+ SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+ rewriter, loc, forallOp.getMixedLowerBound());
+ SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+ rewriter, loc, forallOp.getMixedUpperBound());
+ SmallVector<Value> steps =
+ getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+
+ // Create empty scf.parallel op.
+ auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
+ rewriter.eraseBlock(&parallelOp.getRegion().front());
+ rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
+ parallelOp.getRegion().begin());
+ // Replace the terminator.
+ rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
+ rewriter.replaceOpWithNewOp<scf::ReduceOp>(
+ parallelOp.getRegion().front().getTerminator());
+
+ // Erase the scf.forall op.
+ rewriter.replaceOp(forallOp, parallelOp);
+
+ if (result)
+ *result = parallelOp;
+
+ return success();
+}
+
+namespace {
+struct ForallToParallelLoop final
+ : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
+ void runOnOperation() override {
+ Operation *parentOp = getOperation();
+ IRRewriter rewriter(parentOp->getContext());
+
+ parentOp->walk([&](scf::ForallOp forallOp) {
+ if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
+ return signalPassFailure();
+ }
+ });
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
+ return std::make_unique<ForallToParallelLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
new file mode 100644
index 0000000000000..424ba01fc3a66
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+ // CHECK: scf.reduce
+ return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+ // CHECK: scf.reduce
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ // CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+ // CHECK: func.call @callee(%[[IV3]], %[[IV4]])
+ // CHECK: scf.reduce
+ return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+ // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
+ // CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+ // CHECK: scf.reduce
+ // CHECK: }
+ // CHECK: scf.reduce
+ // CHECK: }
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ scf.forall (%k, %l) in (%ub3, %ub4) {
+ func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+ }
+ }
+ return
+}
diff --git a/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
new file mode 100644
index 0000000000000..b64798e06a4d1
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+ // CHECK: scf.reduce
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%ub1: index, %ub2: index) {
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{expected a single payload op}}
+ transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{expected the payload to be scf.forall}}
+ transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
✅ With the latest revision this PR passed the C/C++ code formatter.
There is currently no path to lower scf.forall to scf.parallel with the goal of targeting the OpenMP dialect. In the SCF->ControlFlow conversion, scf.forall is briefly converted to scf.parallel, but the scf.parallel is lowered directly to a sequential loop. This makes experimenting with scf.forall for CPU execution difficult. This change factors out the rewrite in the SCF->ControlFlow pass into a utility function that can then be used in the SCF->ControlFlow lowering, but also in a separate -scf-forall-to-parallel pass.
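As a rough illustration of the rewrite described above, the sketch below applies it by hand to the two_iters case from the new test file. The "after" half is hand-written rather than captured from mlir-opt, so the SSA names and the exact number of arith.constant ops are assumptions, and `input.mlir` is a hypothetical file name. The parts taken from the diff are that the implicit lower bounds and steps are materialized as index values, the body is moved over unchanged, and the terminator becomes scf.reduce.

```mlir
// Possible invocation (pass name taken from the RUN line of the new test):
//   mlir-opt input.mlir -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))'

func.func private @callee(%i: index, %j: index)

// Before: a forall without shared outputs.
func.func @two_iters(%ub1: index, %ub2: index) {
  scf.forall (%i, %j) in (%ub1, %ub2) {
    func.call @callee(%i, %j) : (index, index) -> ()
  }
  return
}

// After (hand-written approximation, constants deduplicated for readability):
func.func @two_iters_after(%ub1: index, %ub2: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%ub1, %ub2) step (%c1, %c1) {
    func.call @callee(%i, %j) : (index, index) -> ()
    scf.reduce
  }
  return
}
```

The function name `@two_iters_after` exists only so that both halves fit in one module; the pass rewrites the loop in place.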
163f463 to 68baef8
@KavithaTipturMadhu @adam-smnk this would replace our local pass in https://github.com/plaidml/tpp-mlir/blob/main/lib/TPP/Transforms/ConvertForAllToParallelOp.cpp
LGTM. By the way, I am curious what will nested
@Menooker The SCF->OpenMP lowering does not try to do anything clever with nested scf.parallel. For example, this nest:
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
  scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
    "test.payload"(%i, %j) : (index, index) -> ()
  }
}
is lowered to:
omp.parallel {
  omp.wsloop {
    omp.loop_nest (%arg6) : index = (%arg0) to (%arg2) step (%arg4) {
      memref.alloca_scope {
        %1 = llvm.mlir.constant(1 : i64) : i64
        omp.parallel {
          omp.wsloop {
            omp.loop_nest (%arg7) : index = (%arg1) to (%arg3) step (%arg5) {
              memref.alloca_scope {
                "test.payload"(%arg6, %arg7) : (index, index) -> ()
              }
              omp.yield
            }
            omp.terminator
          }
          omp.terminator
        }
      }
      omp.yield
    }
    omp.terminator
  }
  omp.terminator
}
LGTM Thanks
There is currently no path to lower scf.forall to scf.parallel with the goal of targeting the OpenMP dialect.
In the SCF->ControlFlow conversion, scf.forall is briefly converted to scf.parallel, but the scf.parallel is lowered directly to a sequential loop. This makes experimenting with scf.forall for CPU execution difficult.
This change factors out the rewrite in the SCF->ControlFlow pass into a utility function that can then be used in the SCF->ControlFlow lowering and via a separate -scf-forall-to-parallel pass.
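One limitation of the utility is worth picturing: per the diff, forallToParallelLoop only handles scf.forall ops with no results. The sketch below is a hypothetical input (the function and value names are made up for illustration) that still carries a shared output, so the conversion would be expected to reject it.

```mlir
// Hypothetical example: an scf.forall with a tensor shared output.
// The forall has a result, so getOutputs() is non-empty.
func.func @with_outputs(%init: tensor<8xf32>, %v: tensor<1xf32>) -> tensor<8xf32> {
  %res = scf.forall (%i) in (8) shared_outs(%out = %init) -> (tensor<8xf32>) {
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %v into %out[%i] [1] [1]
          : tensor<1xf32> into tensor<8xf32>
    }
  }
  return %res : tensor<8xf32>
}
```

Going by the code in this PR, the utility reports a match failure ("only fully bufferized scf.forall ops can be lowered to scf.parallel") on such a loop, the -scf-forall-to-parallel pass signals failure, and transform.loop.forall_to_parallel emits the silenceable error "unsupported shared outputs (didn't bufferize?)". Bufferizing first, so that the loop writes to memrefs and has no results, should make the conversion applicable.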