Skip to content

[MLIR][OpenMP] Use map format to represent use_device_{addr,ptr} #109810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2024

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Sep 24, 2024

This patch updates the omp.target_data operation to use the same formatting as map clauses on omp.target for use_device_addr and use_device_ptr. This is done so the mapping that is being enforced between op arguments and associated entry block arguments is explicit.

The way it is achieved is by marking these clauses as entry block argument-defining and adjusting printer/parsers accordingly.

As a result of this change, block arguments for use_device_addr come before those for use_device_ptr, which is the opposite of the previous undocumented situation. Some unit tests are updated based on this change, in addition to those updated because of the format change.

@llvmbot
Copy link
Member

llvmbot commented Sep 24, 2024

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch updates the omp.target_data operation to use the same formatting as map clauses on omp.target for use_device_addr and use_device_ptr. This is done so the mapping that is being enforced between op arguments and associated entry block arguments is explicit.

The way it is achieved is by marking these clauses as entry block argument-defining and adjusting printer/parsers accordingly.

As a result of this change, block arguments for use_device_addr come before those for use_device_ptr, which is the opposite of the previous undocumented situation. Some unit tests are updated based on this change, in addition to those updated because of the format change.


Patch is 32.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109810.diff

11 Files Affected:

  • (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+3-2)
  • (modified) flang/test/Lower/OpenMP/target.f90 (+2-4)
  • (modified) flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 (+4-8)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+24-4)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+6)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+27-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+43)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+52-28)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+4-2)
  • (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+7-12)
  • (modified) mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir (+1-2)
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 4d226eaa754c12..61f18008633d50 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -429,13 +429,14 @@ func.func @_QPopenmp_target_data_region() {
 
 func.func @_QPomp_target_data_empty() {
   %0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
-  omp.target_data use_device_addr(%0 : !fir.ref<!fir.array<1024xi32>>) {
+  omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
+    omp.terminator
   }
   return
 }
 
 // CHECK-LABEL:   llvm.func @_QPomp_target_data_empty
-// CHECK: omp.target_data   use_device_addr(%1 : !llvm.ptr) {
+// CHECK: omp.target_data   use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
 // CHECK: }
 
 // -----
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index dedce581436490..ab33b6b3808315 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -506,9 +506,8 @@ subroutine omp_target_device_ptr
    type(c_ptr) :: a
    integer, target :: b
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}})   map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
-   !CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}})
+   !CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}} -> %[[VAL_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
    !$omp target data map(tofrom: a) use_device_ptr(a)
-   !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
    !CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
       a = c_loc(b)
    !CHECK: omp.terminator
@@ -529,9 +528,8 @@ subroutine omp_target_device_addr
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
    !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
    !CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
-   !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]], %[[DEV_ADDR]] : {{.*}}) {
+   !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
    !$omp target data map(tofrom: a) use_device_addr(a)
-   !CHECK: ^bb0(%[[ARG_0:.*]]: !fir.llvm_ptr<!fir.ref<i32>>, %[[ARG_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
    !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
    !CHECK: %[[C10:.*]] = arith.constant 10 : i32
    !CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
index 085f5419fa7f88..cb26246a6e80f0 100644
--- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
+++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
@@ -6,8 +6,7 @@
 ! use_device_ptr to use_device_addr works, without breaking any functionality.
 
 !CHECK: func.func @{{.*}}only_use_device_ptr()
-!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
-!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
 subroutine only_use_device_ptr
     use iso_c_binding
     integer, pointer, dimension(:) :: array
@@ -19,8 +18,7 @@ subroutine only_use_device_ptr
      end subroutine
 
 !CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
-!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
-!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
 subroutine mix_use_device_ptr_and_addr
     use iso_c_binding
     integer, pointer, dimension(:) :: array
@@ -32,8 +30,7 @@ subroutine mix_use_device_ptr_and_addr
      end subroutine
 
      !CHECK: func.func @{{.*}}only_use_device_addr()
-     !CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
-     !CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
+     !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
      subroutine only_use_device_addr
         use iso_c_binding
         integer, pointer, dimension(:) :: array
@@ -45,8 +42,7 @@ subroutine only_use_device_addr
      end subroutine
 
      !CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
-     !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
-     !CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
+     !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
      subroutine mix_use_device_ptr_and_addr_and_map
         use iso_c_binding
         integer :: i, j
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 97e8b368050725..886554f66afffc 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1209,18 +1209,28 @@ class OpenMP_UseDeviceAddrClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
   let arguments = (ins
     Variadic<OpenMP_PointerLikeType>:$use_device_addr_vars
   );
 
-  let optAssemblyFormat = [{
-    `use_device_addr` `(` $use_device_addr_vars `:` type($use_device_addr_vars) `)`
+  let extraClassDeclaration = [{
+    unsigned numUseDeviceAddrBlockArgs() {
+      return getUseDeviceAddrVars().size();
+    }
   }];
 
   let description = [{
     The optional `use_device_addr_vars` specifies the address of the objects in
     the device data environment.
   }];
+
+  // Assembly format not defined because this clause must be processed together
+  // with the first region of the operation, as it defines entry block
+  // arguments.
 }
 
 def OpenMP_UseDeviceAddrClause : OpenMP_UseDeviceAddrClauseSkip<>;
@@ -1234,18 +1244,28 @@ class OpenMP_UseDevicePtrClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
   let arguments = (ins
     Variadic<OpenMP_PointerLikeType>:$use_device_ptr_vars
   );
 
-  let optAssemblyFormat = [{
-    `use_device_ptr` `(` $use_device_ptr_vars `:` type($use_device_ptr_vars) `)`
+  let extraClassDeclaration = [{
+    unsigned numUseDevicePtrBlockArgs() {
+      return getUseDevicePtrVars().size();
+    }
   }];
 
   let description = [{
     The optional `use_device_ptr_vars` specifies the device pointers to the
     corresponding list items in the device data environment.
   }];
+
+  // Assembly format not defined because this clause must be processed together
+  // with the first region of the operation, as it defines entry block
+  // arguments.
 }
 
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index e58ccc4e930210..d2a2b44c042fb7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -987,6 +987,12 @@ def TargetDataOp: OpenMP_Op<"target_data", traits = [
     OpBuilder<(ins CArg<"const TargetDataOperands &">:$clauses)>
   ];
 
+  let assemblyFormat = clausesAssemblyFormat # [{
+    custom<UseDeviceAddrUseDevicePtrRegion>(
+        $region, $use_device_addr_vars, type($use_device_addr_vars),
+        $use_device_ptr_vars, type($use_device_ptr_vars)) attr-dict
+  }];
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 030075eaf45b14..b16fc1ba1c3105 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -45,6 +45,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
                     "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
       return 0;
     }]>,
+    InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.",
+                    "unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.",
+                    "unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
 
     // Unified access methods for clause-associated entry block arguments.
     InterfaceMethod<"Get block arguments defined by `in_reduction`.",
@@ -81,13 +89,31 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
           $_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs(),
           $_op.numTaskReductionBlockArgs());
     }]>,
+    InterfaceMethod<"Get block arguments defined by `use_device_addr`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getUseDeviceAddrBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+          $_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs() +
+          $_op.numTaskReductionBlockArgs(), $_op.numUseDeviceAddrBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `use_device_ptr`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getUseDevicePtrBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+          $_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs() +
+          $_op.numTaskReductionBlockArgs() + $_op.numUseDeviceAddrBlockArgs(),
+          $_op.numUseDevicePtrBlockArgs());
+    }]>,
   ];
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
     unsigned expectedArgs = iface.numInReductionBlockArgs() +
         iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
-        iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
+        iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
+        iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 4d2e6152a4fe4c..e4fe67726ab5fb 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -504,6 +504,8 @@ struct AllRegionParseArgs {
   std::optional<PrivateParseArgs> privateArgs;
   std::optional<ReductionParseArgs> reductionArgs;
   std::optional<ReductionParseArgs> taskReductionArgs;
+  std::optional<MapParseArgs> useDeviceAddrArgs;
+  std::optional<MapParseArgs> useDevicePtrArgs;
 };
 } // namespace
 
@@ -648,6 +650,16 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
     return parser.emitError(parser.getCurrentLocation())
            << "invalid `task_reduction` format";
 
+  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
+                                 args.useDeviceAddrArgs)))
+    return parser.emitError(parser.getCurrentLocation())
+           << "invalid `use_device_addr` format";
+
+  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
+                                 args.useDevicePtrArgs)))
+    return parser.emitError(parser.getCurrentLocation())
+           << "invalid `use_device_addr` format";
+
   return parser.parseRegion(region, entryBlockArgs);
 }
 
@@ -735,6 +747,18 @@ static ParseResult parseTaskReductionRegion(
   return parseBlockArgRegion(parser, region, args);
 }
 
+static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
+    SmallVectorImpl<Type> &useDeviceAddrTypes,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars,
+    SmallVectorImpl<Type> &useDevicePtrTypes) {
+  AllRegionParseArgs args;
+  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
+  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
+  return parseBlockArgRegion(parser, region, args);
+}
+
 //===----------------------------------------------------------------------===//
 // Printers for operations including clauses that define entry block arguments.
 //===----------------------------------------------------------------------===//
@@ -767,6 +791,8 @@ struct AllRegionPrintArgs {
   std::optional<PrivatePrintArgs> privateArgs;
   std::optional<ReductionPrintArgs> reductionArgs;
   std::optional<ReductionPrintArgs> taskReductionArgs;
+  std::optional<MapPrintArgs> useDeviceAddrArgs;
+  std::optional<MapPrintArgs> useDevicePtrArgs;
 };
 } // namespace
 
@@ -849,6 +875,11 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
   printBlockArgClause(p, ctx, "task_reduction",
                       iface.getTaskReductionBlockArgs(),
                       args.taskReductionArgs);
+  printBlockArgClause(p, ctx, "use_device_addr",
+                      iface.getUseDeviceAddrBlockArgs(),
+                      args.useDeviceAddrArgs);
+  printBlockArgClause(p, ctx, "use_device_ptr",
+                      iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
 
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
@@ -925,6 +956,18 @@ static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
   printBlockArgRegion(p, op, region, args);
 }
 
+static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
+                                                 Region &region,
+                                                 ValueRange useDeviceAddrVars,
+                                                 TypeRange useDeviceAddrTypes,
+                                                 ValueRange useDevicePtrVars,
+                                                 TypeRange useDevicePtrTypes) {
+  AllRegionPrintArgs args;
+  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
+  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
+  printBlockArgRegion(p, op, region, args);
+}
+
 /// Verifies Reduction Clause
 static LogicalResult
 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 769cdd57656b51..acec2c243e5161 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2449,8 +2449,8 @@ static void collectMapDataFromMapOperands(
     }
   };
 
-  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
   addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
+  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
 }
 
 static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
@@ -3056,6 +3056,34 @@ convertOmpTargetData(Operation *op, ll...
[truncated]

@skatrak
Copy link
Member Author

skatrak commented Sep 30, 2024

Ping for review!

Copy link
Contributor

@agozillon agozillon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, @TIFitis would be a good second reviewer if he wishes to do so!

Copy link
Member

@TIFitis TIFitis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im happy with the changes. Thanks :)

@skatrak skatrak force-pushed the users/skatrak/block-args-02-asmfmt branch from a757969 to 181f7b6 Compare October 1, 2024 14:46
@skatrak skatrak force-pushed the users/skatrak/block-args-03-usedevice branch from f61e3a6 to c08a30e Compare October 1, 2024 15:01
Base automatically changed from users/skatrak/block-args-02-asmfmt to main October 1, 2024 15:18
This patch updates the `omp.target_data` operation to use the same formatting
as `map` clauses on `omp.target` for `use_device_addr` and `use_device_ptr`.
This is done so the mapping that is being enforced between op arguments and
associated entry block arguments is explicit.

The way it is achieved is by marking these clauses as entry block
argument-defining and adjusting printer/parsers accordingly.

As a result of this change, block arguments for `use_device_addr` come before
those for `use_device_ptr`, which is the opposite of the previous undocumented
situation. Some unit tests are updated based on this change, in addition to
those updated because of the format change.
@skatrak skatrak force-pushed the users/skatrak/block-args-03-usedevice branch from c08a30e to 56649e8 Compare October 1, 2024 15:45
@skatrak skatrak merged commit 5894d4e into main Oct 1, 2024
4 of 5 checks passed
@skatrak skatrak deleted the users/skatrak/block-args-03-usedevice branch October 1, 2024 15:46
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
…m#109810)

This patch updates the `omp.target_data` operation to use the same
formatting as `map` clauses on `omp.target` for `use_device_addr` and
`use_device_ptr`. This is done so the mapping that is being enforced
between op arguments and associated entry block arguments is explicit.

The way it is achieved is by marking these clauses as entry block
argument-defining and adjusting printer/parsers accordingly.

As a result of this change, block arguments for `use_device_addr` come
before those for `use_device_ptr`, which is the opposite of the previous
undocumented situation. Some unit tests are updated based on this
change, in addition to those updated because of the format change.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants