Skip to content

Commit 3633de7

Browse files
authored
[mlir][acc] Handle OpenACC host_data in LegalizeDataValues (#134767)
`LegalizeDataValuesInRegion` is intended to replace the SSA values used in a region with the output of data operations, but it does not handle the OpenACC `host_data` construct. As a result, currently

```
!$acc host_data use_device(%var)
...%var...
!$acc end host_data
```

is lowered to

```
%dev_var = acc.use_device(%var)
acc.host_data data_operands(%dev_var) {
  ...%var...
}
```

This pull request updates `LegalizeDataValuesInRegion` to handle `HostDataOp` so that the lowering instead results in

```
%dev_var = acc.use_device(%var)
acc.host_data data_operands(%dev_var) {
  ...%dev_var...
}
```
1 parent 7491ff7 commit 3633de7

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,10 @@
5858
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
5959
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
6060
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
61-
mlir::acc::DataOp, mlir::acc::DeclareOp
61+
mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
6262
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
6363
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
64-
mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \
65-
mlir::acc::DeclareExitOp
64+
mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
6665
#define ACC_DATA_CONSTRUCT_OPS \
6766
ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
6867
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \

mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
8181
collectVars(op.getDataClauseOperands(), values, hostToDevice);
8282
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
8383
!std::is_same_v<Op, acc::DataOp> &&
84-
!std::is_same_v<Op, acc::DeclareOp>) {
84+
!std::is_same_v<Op, acc::DeclareOp> &&
85+
!std::is_same_v<Op, acc::HostDataOp>) {
8586
collectVars(op.getReductionOperands(), values, hostToDevice);
8687
collectVars(op.getPrivateOperands(), values, hostToDevice);
8788
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
@@ -122,6 +123,8 @@ class LegalizeDataValuesInRegion
122123
collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
123124
} else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124125
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
126+
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
127+
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
125128
} else {
126129
llvm_unreachable("unsupported acc region op");
127130
}

mlir/test/Dialect/OpenACC/legalize-data.mlir

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func.func @test(%a: memref<10xf32>) {
102102
return
103103
}
104104

105-
// CHECK: func.func @test
105+
// CHECK-LABEL: func.func @test
106106
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
107107
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
108108
// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
@@ -140,7 +140,7 @@ func.func @test(%a: memref<10xf32>) {
140140
return
141141
}
142142

143-
// CHECK: func.func @test
143+
// CHECK-LABEL: func.func @test
144144
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
145145
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
146146
// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
@@ -178,7 +178,7 @@ func.func @test(%a: memref<10xf32>) {
178178
return
179179
}
180180

181-
// CHECK: func.func @test
181+
// CHECK-LABEL: func.func @test
182182
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
183183
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
184184
// CHECK: acc.parallel {
@@ -216,7 +216,7 @@ func.func @test(%a: memref<10xf32>) {
216216
return
217217
}
218218

219-
// CHECK: func.func @test
219+
// CHECK-LABEL: func.func @test
220220
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
221221
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
222222
// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
@@ -226,3 +226,23 @@ func.func @test(%a: memref<10xf32>) {
226226
// CHECK: }
227227
// CHECK: acc.yield
228228
// CHECK: }
229+
230+
// -----
231+
232+
func.func @test(%a: memref<10xf32>) {
233+
%devptr = acc.use_device varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
234+
acc.host_data dataOperands(%devptr : memref<10xf32>) {
235+
func.call @foo(%a) : (memref<10xf32>) -> ()
236+
acc.terminator
237+
}
238+
return
239+
}
240+
func.func private @foo(memref<10xf32>)
241+
242+
// CHECK-LABEL: func.func @test
243+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
244+
// CHECK: %[[USE_DEVICE:.*]] = acc.use_device varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
245+
// CHECK: acc.host_data dataOperands(%[[USE_DEVICE]] : memref<10xf32>) {
246+
// DEVICE: func.call @foo(%[[USE_DEVICE]]) : (memref<10xf32>) -> ()
247+
// CHECK: acc.terminator
248+
// CHECK: }

0 commit comments

Comments
 (0)