Skip to content

[mlir][scf] Implement conversion from scf.forall to scf.parallel #94109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 4, 2024

Conversation

sabauma
Copy link
Contributor

@sabauma sabauma commented Jun 1, 2024

There is currently no path to lower scf.forall to scf.parallel with the goal of targeting the OpenMP dialect.

In the SCF->ControlFlow conversion, scf.forall is briefly converted to scf.parallel, but the scf.parallel is lowered directly to a sequential loop. This makes experimenting with scf.forall for CPU execution difficult.

This change factors out the rewrite in the SCF->ControlFlow pass into a utility function that can then be used in the SCF->ControlFlow lowering and via a separate -scf-forall-to-parallel pass.

@llvmbot
Copy link
Member

llvmbot commented Jun 1, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Spenser Bauman (sabauma)

Changes

There is currently no path to lower scf.forall to scf.parallel with the goal of targeting the OpenMP dialect.

In the SCF->ControlFlow conversion, scf.forall is briefly converted to scf.parallel, but the scf.parallel is lowered directly to a sequential loop. This makes experimenting with scf.forall for CPU execution difficult.

This change factors out the rewrite in the SCF->ControlFlow pass into a utility function that can then be used in the SCF->ControlFlow lowering and via a separate -scf-forall-to-parallel pass.


Full diff: https://github.com/llvm/llvm-project/pull/94109.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+26)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.h (+3)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+5)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+6)
  • (modified) mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (+2-27)
  • (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+44)
  • (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp (+83)
  • (added) mlir/test/Dialect/SCF/forall-to-parallel.mlir (+62)
  • (added) mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir (+60)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..3d7fe7b0f093f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
 }
 
+def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.forall into a nest of scf.for operations";
+  let description = [{
+    Converts the `scf.forall` operation pointed to by the given handle into an
+    `scf.parallel` operation.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    Loops with outputs are not supported.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.forall` payload operation.
+    Returns a handle to the new `scf.parallel` operation.
+    Produces a silenceable failure if another number of resulting handles is
+    requested.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 31c3d0eb629d2..fb8411418ff9a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 /// Creates a pass that converts SCF forall loops to SCF for loops.
 std::unique_ptr<Pass> createForallToForLoopPass();
 
+/// Creates a pass that converts SCF forall loops to SCF parallel loops.
+std::unique_ptr<Pass> createForallToParallelLoopPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index a7aeb42d60c0e..9b29affb97c43 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
   let constructor = "mlir::createForallToForLoopPass()";
 }
 
+def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
+  let summary = "Convert SCF forall loops to SCF parallel loops";
+  let constructor = "mlir::createForallToParallelLoopPass()";
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index b063e6e775e63..02ad7ab135f86 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -39,6 +39,12 @@ class WhileOp;
 LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
                               SmallVectorImpl<Operation *> *results = nullptr);
 
+/// Try converting scf.forall into an scf.parallel loop.
+/// The conversion is only supported for forall operations with no results.
+LogicalResult forallToParallelLoop(RewriterBase &rewriter,
+                                   ForallOp forallOp,
+                                   ParallelOp *result = nullptr);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 9eb8a289d7d65..16f1db44acc35 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
                                               PatternRewriter &rewriter) const {
-  Location loc = forallOp.getLoc();
-  if (!forallOp.getOutputs().empty())
-    return rewriter.notifyMatchFailure(
-        forallOp,
-        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
-
-  // Convert mixed bounds and steps to SSA values.
-  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedLowerBound());
-  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedUpperBound());
-  SmallVector<Value> steps =
-      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
-
-  // Create empty scf.parallel op.
-  auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
-  rewriter.eraseBlock(&parallelOp.getRegion().front());
-  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
-                              parallelOp.getRegion().begin());
-  // Replace the terminator.
-  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
-  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
-      parallelOp.getRegion().front().getTerminator());
-
-  // Erase the scf.forall op.
-  rewriter.replaceOp(forallOp, parallelOp);
-  return success();
+  return scf::forallToParallelLoop(rewriter, forallOp);
 }
 
 void mlir::populateSCFToControlFlowConversionPatterns(
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..7bdf3ac1d6ac3 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ForallToForOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
+                                     transform::TransformResults &results,
+                                     transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload))
+    return emitSilenceableError() << "expected a single payload op";
+
+  auto target = dyn_cast<scf::ForallOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.forall";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  if (!target.getOutputs().empty()) {
+    return emitSilenceableError()
+           << "unsupported shared outputs (didn't bufferize?)";
+  }
+
+  if (getNumResults() != 1) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "op expects one result, given "
+                                       << getNumResults();
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  scf::ParallelOp opResult;
+  if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to convert forall into parallel";
+    return diag;
+  }
+
+  results.set(cast<OpResult>(getTransformed()[0]), {opResult});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOutlineOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e7671c9cc28f8..d363ffe941fce 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ForallToFor.cpp
+  ForallToParallel.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
new file mode 100644
index 0000000000000..8882d083635f2
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -0,0 +1,83 @@
+//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForallOp's into SCF.ParallelOps's.
+//
+//===----------------------------------------------------------------------===//
+
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
+                                              scf::ForallOp forallOp,
+                                              scf::ParallelOp *result) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  Location loc = forallOp.getLoc();
+  if (!forallOp.getOutputs().empty())
+    return rewriter.notifyMatchFailure(
+        forallOp,
+        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
+
+  // Convert mixed bounds and steps to SSA values.
+  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedLowerBound());
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+
+  // Create empty scf.parallel op.
+  auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
+  rewriter.eraseBlock(&parallelOp.getRegion().front());
+  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
+                              parallelOp.getRegion().begin());
+  // Replace the terminator.
+  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
+  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
+      parallelOp.getRegion().front().getTerminator());
+
+  // Erase the scf.forall op.
+  rewriter.replaceOp(forallOp, parallelOp);
+
+  if (result)
+    *result = parallelOp;
+
+  return success();
+}
+
+namespace {
+struct ForallToParallelLoop final
+    : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    parentOp->walk([&](scf::ForallOp forallOp) {
+      if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
+        return signalPassFailure();
+      }
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
+  return std::make_unique<ForallToParallelLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
new file mode 100644
index 0000000000000..424ba01fc3a66
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK:   func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+  // CHECK:   scf.reduce
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK:   func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+  // CHECK:   scf.reduce
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  // CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK:   func.call @callee(%[[IV3]], %[[IV4]])
+  // CHECK:   scf.reduce
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
+  // CHECK:   scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+  // CHECK:     scf.reduce
+  // CHECK:   }
+  // CHECK:   scf.reduce
+  // CHECK: }
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    scf.forall (%k, %l) in (%ub3, %ub4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  return
+}
diff --git a/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
new file mode 100644
index 0000000000000..b64798e06a4d1
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK:   func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+  // CHECK:   scf.reduce
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected a single payload op}}
+    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected the payload to be scf.forall}}
+    transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

Copy link

github-actions bot commented Jun 1, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@rengolin rengolin requested a review from adam-smnk June 1, 2024 12:44
There is currently no path to lower scf.forall to scf.parallel with the
goal of targeting the OpenMP dialect.

In the SCF->ControlFlow conversion, scf.forall is briefly converted to
scf.parallel, but the scf.parallel is lowered directly to a sequential
loop. This makes experimenting with scf.forall for CPU execution
difficult.

This change factors out the rewrite in the SCF->ControlFlow pass into
a utility function that can then be used in the SCF->ControlFlow
lowering, but also in a separate -scf-forall-to-parallel pass.
@sabauma sabauma force-pushed the forall-op-lowering branch from 163f463 to 68baef8 Compare June 1, 2024 12:51
@rengolin
Copy link
Member

rengolin commented Jun 1, 2024

@Menooker
Copy link

Menooker commented Jun 3, 2024

LGTM. By the way, I am curious what will nested scf.parallel be lowered to, in OpenMP target? Will it be nested parallelism in OMP?

@sabauma
Copy link
Contributor Author

sabauma commented Jun 3, 2024

LGTM. By the way, I am curious what will nested scf.parallel be lowered to, in OpenMP target? Will it be nested parallelism in OMP?

@Menooker The SCF->OpenMP lowering does not try to do anything clever with nested scf.parallel. The lowering produces nested omp.parallel/omp.wsloop sequences.

  scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
    scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
      "test.payload"(%i, %j) : (index, index) -> ()
    }
  }

  omp.parallel {
      omp.wsloop {
        omp.loop_nest (%arg6) : index = (%arg0) to (%arg2) step (%arg4) {
          memref.alloca_scope  {
            %1 = llvm.mlir.constant(1 : i64) : i64
            omp.parallel {
              omp.wsloop {
                omp.loop_nest (%arg7) : index = (%arg1) to (%arg3) step (%arg5) {
                  memref.alloca_scope  {
                    "test.payload"(%arg6, %arg7) : (index, index) -> ()
                  }
                  omp.yield
                }
                omp.terminator
              }
              omp.terminator
            }
          }
          omp.yield
        }
        omp.terminator
      }
      omp.terminator
   }

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM Thanks

@sabauma sabauma merged commit 0b665c3 into llvm:main Jun 4, 2024
7 checks passed
keith added a commit that referenced this pull request Jun 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants