Skip to content

[MLIR][OpenMP] Normalize handling of entry block arguments #109808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2024

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Sep 24, 2024

This patch introduces a new MLIR interface for the OpenMP dialect aimed at providing a uniform way of verifying and handling entry block arguments defined by OpenMP clauses.

The approach consists in defining a set of overrideable methods that return the number of block arguments the operation holds regarding each of the clauses that may define them. These by default return 0, but they are overriden by the corresponding clause through the extraClassDeclaration mechanism.

Another set of interface methods to get the actual lists of block arguments is defined, which is implemented based on the previously described methods. These implicitly define a standardized ordering between the list of block arguments associated to each clause, based on the alphabetical ordering of their names. They should be the preferred way of matching operation arguments and entry block arguments to that operation's first region.

Some updates are made to the printing/parsing of omp.parallel to follow the expected order between private and reduction clauses, as well as the MLIR to LLVM IR translation pass to access block arguments using the new interface. Unit tests of operations impacted by additional verification checks and sorting of entry block arguments.

@llvmbot
Copy link
Member

llvmbot commented Sep 24, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch introduces a new MLIR interface for the OpenMP dialect aimed at providing a uniform way of verifying and handling entry block arguments defined by OpenMP clauses.

The approach consists in defining a set of overrideable methods that return the number of block arguments the operation holds regarding each of the clauses that may define them. These by default return 0, but they are overriden by the corresponding clause through the extraClassDeclaration mechanism.

Another set of interface methods to get the actual lists of block arguments is defined, which is implemented based on the previously described methods. These implicitly define a standardized ordering between the list of block arguments associated to each clause, based on the alphabetical ordering of their names. They should be the preferred way of matching operation arguments and entry block arguments to that operation's first region.

Some updates are made to the printing/parsing of omp.parallel to follow the expected order between private and reduction clauses, as well as the MLIR to LLVM IR translation pass to access block arguments using the new interface. Unit tests of operations impacted by additional verification checks and sorting of entry block arguments.


Patch is 32.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109808.diff

11 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+19-10)
  • (modified) flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 (+2-2)
  • (modified) flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 (+2-2)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+28-11)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+6-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+80)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+17-17)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+19-26)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+4)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+17-6)
  • (modified) mlir/test/Target/LLVMIR/openmp-private.mlir (+1-1)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 960286732c90c2..e9095d631beb7b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -472,17 +472,26 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
 /// \param [in] infoAccessor       - for a private variable, this returns the
 /// data we want to merge: type or location.
 /// \param [out] allRegionArgsInfo - the merged list of region info.
+/// \param [in] addBeforePrivate - `true` if the passed information goes before
+/// private information.
 template <typename OMPOp, typename InfoTy>
 static void
 mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
                      llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
-                     llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
+                     llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo,
+                     bool addBeforePrivate) {
   mlir::OperandRange privateVars = op.getPrivateVars();
 
-  llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
-                  [](InfoTy i) { return i; });
+  if (addBeforePrivate)
+    llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+                    [](InfoTy i) { return i; });
+
   llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo),
                   infoAccessor);
+
+  if (!addBeforePrivate)
+    llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+                    [](InfoTy i) { return i; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -868,12 +877,12 @@ static void genBodyOfTargetOp(
   mergePrivateVarsInfo(targetOp, mapSymTypes,
                        llvm::function_ref<mlir::Type(mlir::Value)>{
                            [](mlir::Value v) { return v.getType(); }},
-                       allRegionArgTypes);
+                       allRegionArgTypes, /*addBeforePrivate=*/true);
 
   mergePrivateVarsInfo(targetOp, mapSymLocs,
                        llvm::function_ref<mlir::Location(mlir::Value)>{
                            [](mlir::Value v) { return v.getLoc(); }},
-                       allRegionArgLocs);
+                       allRegionArgLocs, /*addBeforePrivate=*/true);
 
   mlir::Block *regionBlock = firOpBuilder.createBlock(
       &region, {}, allRegionArgTypes, allRegionArgLocs);
@@ -1478,21 +1487,21 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     mergePrivateVarsInfo(parallelOp, reductionTypes,
                          llvm::function_ref<mlir::Type(mlir::Value)>{
                              [](mlir::Value v) { return v.getType(); }},
-                         allRegionArgTypes);
+                         allRegionArgTypes, /*addBeforePrivate=*/false);
 
     llvm::SmallVector<mlir::Location> allRegionArgLocs;
     mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs),
                          llvm::function_ref<mlir::Location(mlir::Value)>{
                              [](mlir::Value v) { return v.getLoc(); }},
-                         allRegionArgLocs);
+                         allRegionArgLocs, /*addBeforePrivate=*/false);
 
     mlir::Region &region = parallelOp.getRegion();
     firOpBuilder.createBlock(&region, /*insertPt=*/{}, allRegionArgTypes,
                              allRegionArgLocs);
 
-    llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
-    allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
-                      dsp->getDelayedPrivSymbols().end());
+    llvm::SmallVector<const semantics::Symbol *> allSymbols(
+        dsp->getDelayedPrivSymbols());
+    allSymbols.append(reductionSyms.begin(), reductionSyms.end());
 
     unsigned argIdx = 0;
     for (const semantics::Symbol *arg : allSymbols) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 29439571179322..6c00bb23f15b96 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
index d814b2b0ff0f31..38139e52ce95cb 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
@@ -29,5 +29,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index c579ba6e751d2b..876d53766a0ca1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -451,7 +451,7 @@ class OpenMP_InReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -472,6 +472,8 @@ class OpenMP_InReductionClauseSkip<
       return SmallVector<Value>(getInReductionVars().begin(),
                                 getInReductionVars().end());
     }
+
+    unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
   }];
 
   // Description varies depending on the operation.
@@ -575,6 +577,8 @@ class OpenMP_MapClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
+    // Not adding the BlockArgOpenMPOpInterface here because omp.target is the
+    // only operation defining block arguments for `map` clauses.
     MapClauseOwningOpInterface
   ];
 
@@ -923,6 +927,10 @@ class OpenMP_PrivateClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
   let arguments = (ins
     Variadic<AnyType>:$private_vars,
     OptionalAttr<SymbolRefArrayAttr>:$private_syms
@@ -933,6 +941,10 @@ class OpenMP_PrivateClauseSkip<
       custom<PrivateList>($private_vars, type($private_vars), $private_syms) `)`
   }];
 
+  let extraClassDeclaration = [{
+    unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
+  }];
+
   // TODO: Add description.
 }
 
@@ -973,7 +985,7 @@ class OpenMP_ReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -991,6 +1003,7 @@ class OpenMP_ReductionClauseSkip<
   let extraClassDeclaration = [{
     /// Returns the number of reduction variables.
     unsigned getNumReductionVars() { return getReductionVars().size(); }
+    unsigned numReductionBlockArgs() { return getReductionVars().size(); }
   }];
 
   // Description varies depending on the operation.
@@ -1104,7 +1117,7 @@ class OpenMP_TaskReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -1119,6 +1132,18 @@ class OpenMP_TaskReductionClauseSkip<
                                $task_reduction_byref, $task_reduction_syms) `)`
   }];
 
+  let extraClassDeclaration = [{
+    /// Returns the reduction variables.
+    SmallVector<Value> getReductionVars() {
+      return SmallVector<Value>(getTaskReductionVars().begin(),
+                                getTaskReductionVars().end());
+    }
+
+    unsigned numTaskReductionBlockArgs() {
+      return getTaskReductionVars().size();
+    }
+  }];
+
   let description = [{
     The `task_reduction` clause specifies a reduction among tasks. For each list
     item, the number of copies is unspecified. Any copies associated with the
@@ -1130,14 +1155,6 @@ class OpenMP_TaskReductionClauseSkip<
     attribute, and whether the reduction variable should be passed into the
     reduction region by value or by reference in `task_reduction_byref`.
   }];
-
-  let extraClassDeclaration = [{
-    /// Returns the reduction variables.
-    SmallVector<Value> getReductionVars() {
-      return SmallVector<Value>(getTaskReductionVars().begin(),
-                                getTaskReductionVars().end());
-    }
-  }];
 }
 
 def OpenMP_TaskReductionClause : OpenMP_TaskReductionClauseSkip<>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9d2123a2e9bf52..326bdd3bbc9463 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1043,7 +1043,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
 //===----------------------------------------------------------------------===//
 
 def TargetOp : OpenMP_Op<"target", traits = [
-    AttrSizedOperandSegments, IsolatedFromAbove, OutlineableOpenMPOpInterface
+    AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
+    OutlineableOpenMPOpInterface
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
     OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
@@ -1065,6 +1066,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
 
+  let extraClassDeclaration = [{
+    unsigned numMapBlockArgs() { return getMapVars().size(); }
+  }] # clausesExtraClassDeclaration;
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0078e22b1c89a6..030075eaf45b14 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -15,6 +15,86 @@
 
 include "mlir/IR/OpBase.td"
 
+def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
+  let description = [{
+    OpenMP operations that define entry block arguments as part of the
+    representation of its clauses.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    // Default-implemented methods to be overriden by the corresponding clauses.
+    InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
+                    "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `map`.",
+                    "unsigned", "numMapBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `private`.",
+                    "unsigned", "numPrivateBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `reduction`.",
+                    "unsigned", "numReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
+                    "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+
+    // Unified access methods for clause-associated entry block arguments.
+    InterfaceMethod<"Get block arguments defined by `in_reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getInReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().take_front(
+          $_op.numInReductionBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `map`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getMapBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs(), $_op.numMapBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `private`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getPrivateBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs(),
+          $_op.numPrivateBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+          $_op.numPrivateBlockArgs(), $_op.numReductionBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `task_reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getTaskReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+          $_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs(),
+          $_op.numTaskReductionBlockArgs());
+    }]>,
+  ];
+
+  let verify = [{
+    auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
+    unsigned expectedArgs = iface.numInReductionBlockArgs() +
+        iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
+        iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
+    if ($_op->getRegion(0).getNumArguments() < expectedArgs)
+      return $_op->emitOpError() << "expected at least " << expectedArgs
+                                 << " entry block argument(s)";
+    return ::mlir::success();
+  }];
+}
+
 def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
   let description = [{
     OpenMP operations whose region will be outlined will implement this
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index db47276dcefe95..7ca7a2afbdbbc4 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -536,13 +536,6 @@ static ParseResult parseParallelRegion(
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
   llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
 
-  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
-    if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
-                                         reductionTypes, reductionByref,
-                                         reductionSyms, regionPrivateArgs)))
-      return failure();
-  }
-
   if (succeeded(parser.parseOptionalKeyword("private"))) {
     auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
     if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
@@ -557,6 +550,13 @@ static ParseResult parseParallelRegion(
     }
   }
 
+  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
+    if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
+                                         reductionTypes, reductionByref,
+                                         reductionSyms, regionPrivateArgs)))
+      return failure();
+  }
+
   return parser.parseRegion(region, regionPrivateArgs);
 }
 
@@ -566,18 +566,9 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                 DenseBoolArrayAttr reductionByref,
                                 ArrayAttr reductionSyms, ValueRange privateVars,
                                 TypeRange privateTypes, ArrayAttr privateSyms) {
-  if (reductionSyms) {
-    auto *argsBegin = region.front().getArguments().begin();
-    MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
-    printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
-                              reductionTypes, reductionByref, reductionSyms);
-  }
-
   if (privateSyms) {
     auto *argsBegin = region.front().getArguments().begin();
-    MutableArrayRef argsSubrange(argsBegin + reductionVars.size(),
-                                 argsBegin + reductionVars.size() +
-                                     privateTypes.size());
+    MutableArrayRef argsSubrange(argsBegin, argsBegin + privateTypes.size());
     mlir::SmallVector<bool> isByRefVec;
     isByRefVec.resize(privateTypes.size(), false);
     DenseBoolArrayAttr isByRef =
@@ -587,6 +578,15 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                               privateTypes, isByRef, privateSyms);
   }
 
+  if (reductionSyms) {
+    auto *argsBegin = region.front().getArguments().begin();
+    MutableArrayRef argsSubrange(argsBegin + privateVars.size(),
+                                 argsBegin + privateVars.size() +
+                                     reductionTypes.size());
+    printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
+                              reductionTypes, reductionByref, reductionSyms);
+  }
+
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0cba8d80681f13..769cdd57656b51 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -920,7 +920,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
   DenseMap<Value, llvm::Value *> reductionVariableMap;
 
   MutableArrayRef<BlockArgument> reductionArgs =
-      sectionsOp.getRegion().getArguments();
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
 
   if (failed(allocAndInitializeReductionVars(
           sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -954,8 +954,10 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
       // variables
       assert(region.getNumArguments() ==
              sectionsOp.getRegion().getNumArguments());
-      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
-               sectionsOp.getRegion().getArguments(), region.getArguments())) {
+      for (auto [sectionsArg, sectionArg] :
+           llvm::zip_equal(cast<omp::BlockArgOpenMPOpInterface>(*sectionsOp)
+                               .getReductionBlockArgs(),
+                           region.getArguments())) {
         llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
         assert(llvmVal);
         moduleTranslation.mapValue(sectionArg, llvmVal);
@@ -1216,7 +1218,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   DenseMap<Value, llvm::Value *> reductionVariableMap;
 
   MutableArrayRef<BlockArgument> reductionArgs =
-      wsloopOp.getRegion().getArguments();
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
 
   if (failed(allocAndInitializeReductionVars(
           wsloopOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -1329,31 +1331,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
 class OmpParallelOpConversionManager {
 public:
   OmpParallelOpConversionManager(omp::ParallelOp opInst)
-      : region(opInst.getRegion()), privateVars(opInst.ge...
[truncated]

Copy link
Contributor

@bhandarkar-pranav bhandarkar-pranav left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this PR. The only suggestion I have is to make it explicit in the docs that the order of the clauses in defining the operations matters

@skatrak
Copy link
Member Author

skatrak commented Sep 26, 2024

Thank you for this PR. The only suggestion I have is to make it explicit in the docs that the order of the clauses in defining the operations matters

Thank you for taking a look, PR #109811 in the stack should address this. Feel free to add any comments there if you find something is not clear.

@skatrak skatrak force-pushed the users/skatrak/block-args-01-iface branch from 82a0c88 to 4555bc4 Compare September 30, 2024 11:43
@bhandarkar-pranav
Copy link
Contributor

Thank you for this PR. The only suggestion I have is to make it explicit in the docs that the order of the clauses in defining the operations matters

Thank you for taking a look, PR #109811 in the stack should address this. Feel free to add any comments there if you find something is not clear.

Sorry for the delay. I had to take a couple of days off unexpectedly.
Sounds good. LGTM, this PR.

This patch introduces a new MLIR interface for the OpenMP dialect aimed at
providing a uniform way of verifying and handling entry block arguments defined
by OpenMP clauses.

The approach consists in defining a set of overrideable methods that return the
number of block arguments the operation holds regarding each of the clauses
that may define them. These by default return 0, but they are overriden by the
corresponding clause through the `extraClassDeclaration` mechanism.

Another set of interface methods to get the actual lists of block arguments is
defined, which is implemented based on the previously described methods. These
implicitly define a standardized ordering between the list of block arguments
associated to each clause, based on the alphabetical ordering of their names.
They should be the preferred way of matching operation arguments and entry
block arguments to that operation's first region.

Some updates are made to the printing/parsing of `omp.parallel` to follow the
expected order between `private` and `reduction` clauses, as well as the MLIR
to LLVM IR translation pass to access block arguments using the new interface.
Unit tests of operations impacted by additional verification checks and
sorting of entry block arguments.
@skatrak skatrak force-pushed the users/skatrak/block-args-01-iface branch from 4555bc4 to e5a0b3c Compare October 1, 2024 13:29
@skatrak skatrak merged commit d0f6777 into main Oct 1, 2024
6 of 7 checks passed
@skatrak skatrak deleted the users/skatrak/block-args-01-iface branch October 1, 2024 14:04
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
This patch introduces a new MLIR interface for the OpenMP dialect aimed
at providing a uniform way of verifying and handling entry block
arguments defined by OpenMP clauses.

The approach consists in defining a set of overrideable methods that
return the number of block arguments the operation holds regarding each
of the clauses that may define them. These by default return 0, but they
are overriden by the corresponding clause through the
`extraClassDeclaration` mechanism.

Another set of interface methods to get the actual lists of block
arguments is defined, which is implemented based on the previously
described methods. These implicitly define a standardized ordering
between the list of block arguments associated to each clause, based on
the alphabetical ordering of their names. They should be the preferred
way of matching operation arguments and entry block arguments to that
operation's first region.

Some updates are made to the printing/parsing of `omp.parallel` to
follow the expected order between `private` and `reduction` clauses, as
well as the MLIR to LLVM IR translation pass to access block arguments
using the new interface. Unit tests of operations impacted by additional
verification checks and sorting of entry block arguments.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants