-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][openacc] Add private/reduction in legalize data pass #80882
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-openacc @llvm/pr-subscribers-mlir-openacc Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesThis is a follow up to #80351 and adds private and reduction operands from acc.loop, acc.parallel and acc.serial operations. Full diff: https://github.com/llvm/llvm-project/pull/80882.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
index ef44a0ec68d9ca..db6b472ff9733a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -24,10 +24,10 @@ using namespace mlir;
namespace {
-template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
- llvm::SmallVector<std::pair<Value, Value>> values;
- for (auto operand : op.getDataClauseOperands()) {
+static void collectPtrs(mlir::ValueRange operands,
+ llvm::SmallVector<std::pair<Value, Value>> &values,
+ bool hostToDevice) {
+ for (auto operand : operands) {
Value varPtr = acc::getVarPtr(operand.getDefiningOp());
Value accPtr = acc::getAccPtr(operand.getDefiningOp());
if (varPtr && accPtr) {
@@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
values.push_back({accPtr, varPtr});
}
}
+}
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+ llvm::SmallVector<std::pair<Value, Value>> values;
+
+ if constexpr (std::is_same_v<Op, acc::LoopOp>) {
+ collectPtrs(op.getReductionOperands(), values, hostToDevice);
+ collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+ } else {
+ collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+ if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+ collectPtrs(op.getReductionOperands(), values, hostToDevice);
+ collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
+ collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
+ }
+ }
for (auto p : values)
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
@@ -50,7 +67,7 @@ struct LegalizeDataInRegion
bool replaceHostVsDevice = this->hostToDevice.getValue();
funcOp.walk([&](Operation *op) {
- if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+ if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
return;
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -59,6 +76,8 @@ struct LegalizeDataInRegion
collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
} else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+ } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
+ collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
}
});
}
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 4c86223c720a33..113fe90450ab7b 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) {
// CHECK: }
// CHECK: acc.yield
// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+ acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.parallel {
+ acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel {
+// CHECK: acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+ acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
|
@llvm/pr-subscribers-mlir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesThis is a follow up to #80351 and adds private and reduction operands from acc.loop, acc.parallel and acc.serial operations. Full diff: https://github.com/llvm/llvm-project/pull/80882.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
index ef44a0ec68d9ca..db6b472ff9733a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -24,10 +24,10 @@ using namespace mlir;
namespace {
-template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
- llvm::SmallVector<std::pair<Value, Value>> values;
- for (auto operand : op.getDataClauseOperands()) {
+static void collectPtrs(mlir::ValueRange operands,
+ llvm::SmallVector<std::pair<Value, Value>> &values,
+ bool hostToDevice) {
+ for (auto operand : operands) {
Value varPtr = acc::getVarPtr(operand.getDefiningOp());
Value accPtr = acc::getAccPtr(operand.getDefiningOp());
if (varPtr && accPtr) {
@@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
values.push_back({accPtr, varPtr});
}
}
+}
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+ llvm::SmallVector<std::pair<Value, Value>> values;
+
+ if constexpr (std::is_same_v<Op, acc::LoopOp>) {
+ collectPtrs(op.getReductionOperands(), values, hostToDevice);
+ collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+ } else {
+ collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+ if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+ collectPtrs(op.getReductionOperands(), values, hostToDevice);
+ collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
+ collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
+ }
+ }
for (auto p : values)
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
@@ -50,7 +67,7 @@ struct LegalizeDataInRegion
bool replaceHostVsDevice = this->hostToDevice.getValue();
funcOp.walk([&](Operation *op) {
- if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+ if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
return;
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -59,6 +76,8 @@ struct LegalizeDataInRegion
collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
} else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+ } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
+ collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
}
});
}
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 4c86223c720a33..113fe90450ab7b 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) {
// CHECK: }
// CHECK: acc.yield
// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+ acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.parallel {
+ acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel {
+// CHECK: acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+ acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Thanks!
This is a follow up to #80351 and adds private and reduction operands from acc.loop, acc.parallel and acc.serial operations.