Skip to content

[mlir][openacc] Add private/reduction in legalize data pass #80882

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 6, 2024

Conversation

clementval
Copy link
Contributor

This is a follow up to #80351 and adds private and reduction operands from acc.loop, acc.parallel and acc.serial operations.

@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2024

@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This is a follow up to #80351 and adds private and reduction operands from acc.loop, acc.parallel and acc.serial operations.


Full diff: https://github.com/llvm/llvm-project/pull/80882.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp (+24-5)
  • (modified) mlir/test/Dialect/OpenACC/legalize-data.mlir (+114)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
index ef44a0ec68d9ca..db6b472ff9733a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -24,10 +24,10 @@ using namespace mlir;
 
 namespace {
 
-template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
-  llvm::SmallVector<std::pair<Value, Value>> values;
-  for (auto operand : op.getDataClauseOperands()) {
+static void collectPtrs(mlir::ValueRange operands,
+                        llvm::SmallVector<std::pair<Value, Value>> &values,
+                        bool hostToDevice) {
+  for (auto operand : operands) {
     Value varPtr = acc::getVarPtr(operand.getDefiningOp());
     Value accPtr = acc::getAccPtr(operand.getDefiningOp());
     if (varPtr && accPtr) {
@@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
         values.push_back({accPtr, varPtr});
     }
   }
+}
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+  llvm::SmallVector<std::pair<Value, Value>> values;
+
+  if constexpr (std::is_same_v<Op, acc::LoopOp>) {
+    collectPtrs(op.getReductionOperands(), values, hostToDevice);
+    collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+  } else {
+    collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+    if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+      collectPtrs(op.getReductionOperands(), values, hostToDevice);
+      collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
+      collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
+    }
+  }
 
   for (auto p : values)
     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
@@ -50,7 +67,7 @@ struct LegalizeDataInRegion
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
     funcOp.walk([&](Operation *op) {
-      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
         return;
 
       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -59,6 +76,8 @@ struct LegalizeDataInRegion
         collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
       } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
         collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+      } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
+        collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
       }
     });
   }
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 4c86223c720a33..113fe90450ab7b 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) {
 // CHECK:   }
 // CHECK:   acc.yield
 // CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+    acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK:   acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel {
+    acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel  {
+// CHECK:   acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+    acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK:   acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }

@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2024

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This is a follow up to #80351 and adds private and reduction operands from acc.loop, acc.parallel and acc.serial operations.


Full diff: https://github.com/llvm/llvm-project/pull/80882.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp (+24-5)
  • (modified) mlir/test/Dialect/OpenACC/legalize-data.mlir (+114)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
index ef44a0ec68d9ca..db6b472ff9733a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -24,10 +24,10 @@ using namespace mlir;
 
 namespace {
 
-template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
-  llvm::SmallVector<std::pair<Value, Value>> values;
-  for (auto operand : op.getDataClauseOperands()) {
+static void collectPtrs(mlir::ValueRange operands,
+                        llvm::SmallVector<std::pair<Value, Value>> &values,
+                        bool hostToDevice) {
+  for (auto operand : operands) {
     Value varPtr = acc::getVarPtr(operand.getDefiningOp());
     Value accPtr = acc::getAccPtr(operand.getDefiningOp());
     if (varPtr && accPtr) {
@@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
         values.push_back({accPtr, varPtr});
     }
   }
+}
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+  llvm::SmallVector<std::pair<Value, Value>> values;
+
+  if constexpr (std::is_same_v<Op, acc::LoopOp>) {
+    collectPtrs(op.getReductionOperands(), values, hostToDevice);
+    collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+  } else {
+    collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+    if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+      collectPtrs(op.getReductionOperands(), values, hostToDevice);
+      collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
+      collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
+    }
+  }
 
   for (auto p : values)
     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
@@ -50,7 +67,7 @@ struct LegalizeDataInRegion
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
     funcOp.walk([&](Operation *op) {
-      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
         return;
 
       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -59,6 +76,8 @@ struct LegalizeDataInRegion
         collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
       } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
         collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+      } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
+        collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
       }
     });
   }
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 4c86223c720a33..113fe90450ab7b 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) {
 // CHECK:   }
 // CHECK:   acc.yield
 // CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+    acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK:   acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel {
+    acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel  {
+// CHECK:   acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+    acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK:   acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Thanks!

@clementval clementval merged commit 4c9717c into llvm:main Feb 6, 2024
@clementval clementval deleted the acc_legalize_data2 branch February 6, 2024 21:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants