Skip to content

Commit fa7d0d3

Browse files
committed
[mlir][openacc] Add legalize data pass for compute operation (#80351)
This patch adds a simple pass to replace the uses inside compute operation. It replaces the `varPtr` values with their corresponding `accPtr` values gathered through the dataClauseOperands. private and reductions variables are not included in this pass since they will normally be replace when they are materialized.
1 parent dd22140 commit fa7d0d3

File tree

12 files changed

+305
-20
lines changed

12 files changed

+305
-20
lines changed

flang/include/flang/Optimizer/Support/InitFIR.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Affine/Passes.h"
2020
#include "mlir/Dialect/Complex/IR/Complex.h"
2121
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
22+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
2223
#include "mlir/InitAllDialects.h"
2324
#include "mlir/Pass/Pass.h"
2425
#include "mlir/Pass/PassRegistry.h"
@@ -74,6 +75,7 @@ inline void loadDialects(mlir::MLIRContext &context) {
7475
/// Register the standard passes we use. This comes from registerAllPasses(),
7576
/// but is a smaller set since we aren't using many of the passes found there.
7677
inline void registerMLIRPassesForFortranTools() {
78+
mlir::acc::registerOpenACCPasses();
7779
mlir::registerCanonicalizerPass();
7880
mlir::registerCSEPass();
7981
mlir::affine::registerAffineLoopFusionPass();
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s
2+
3+
func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
4+
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
5+
%1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
6+
acc.parallel dataOperands(%1 : !fir.ref<i32>) {
7+
%c0_i32 = arith.constant 0 : i32
8+
hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
9+
acc.yield
10+
}
11+
acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
12+
return
13+
}
14+
15+
// CHECK-LABEL: func.func @_QPsub1
16+
// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
17+
// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
18+
// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
19+
// CHECK: acc.parallel dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
20+
// CHECK: %c0_i32 = arith.constant 0 : i32
21+
// CHECK: hlfir.assign %c0{{.*}} to %[[COPYIN]] : i32, !fir.ref<i32>
22+
// CHECK: acc.yield
23+
// CHECK: }
24+
// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}

mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
add_subdirectory(Transforms)
2+
13
set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenACC/ACC.td)
24
mlir_tablegen(AccCommon.td --gen-directive-decl --directives-dialect=OpenACC)
35
add_public_tablegen_target(acc_common_td)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenACC)
3+
add_public_tablegen_target(MLIROpenACCPassIncGen)
4+
5+
add_mlir_doc(Passes OpenACCPasses ./ -gen-pass-doc)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===- Passes.h - OpenACC Passes Construction and Registration ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
10+
#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
11+
12+
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
13+
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
14+
#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
15+
#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
16+
#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
17+
#include "mlir/Pass/Pass.h"
18+
19+
#define GEN_PASS_DECL
20+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
21+
22+
namespace mlir {
23+
24+
namespace func {
25+
class FuncOp;
26+
} // namespace func
27+
28+
namespace acc {
29+
30+
/// Create a pass to replace ssa values in region with device/host values.
31+
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
32+
33+
/// Generate the code for registering conversion passes.
34+
#define GEN_PASS_REGISTRATION
35+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
36+
37+
} // namespace acc
38+
} // namespace mlir
39+
40+
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===-- Passes.td - OpenACC pass definition file -----------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
10+
#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
11+
12+
include "mlir/Pass/PassBase.td"
13+
14+
def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
15+
let summary = "Legalize the data in the compute region";
16+
let description = [{
17+
This pass replace uses of varPtr in the compute region with their accPtr
18+
gathered from the data clause operands.
19+
}];
20+
let options = [
21+
Option<"hostToDevice", "host-to-device", "bool", "true",
22+
"Replace varPtr uses with accPtr if true. Replace accPtr uses with "
23+
"varPtr if false">
24+
];
25+
let constructor = "::mlir::acc::createLegalizeDataInRegion()";
26+
}
27+
28+
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES

mlir/include/mlir/InitAllPasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
3535
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
3636
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
37+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
3738
#include "mlir/Dialect/SCF/Transforms/Passes.h"
3839
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
3940
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -64,6 +65,7 @@ inline void registerAllPasses() {
6465
registerConversionPasses();
6566

6667
// Dialect passes
68+
acc::registerOpenACCPasses();
6769
affine::registerAffinePasses();
6870
amdgpu::registerAMDGPUPasses();
6971
registerAsyncPasses();
Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,2 @@
1-
add_mlir_dialect_library(MLIROpenACCDialect
2-
IR/OpenACC.cpp
3-
4-
ADDITIONAL_HEADER_DIRS
5-
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
6-
7-
DEPENDS
8-
MLIROpenACCOpsIncGen
9-
MLIROpenACCEnumsIncGen
10-
MLIROpenACCAttributesIncGen
11-
MLIROpenACCOpsInterfacesIncGen
12-
MLIROpenACCTypeInterfacesIncGen
13-
14-
LINK_LIBS PUBLIC
15-
MLIRIR
16-
MLIRLLVMDialect
17-
MLIRMemRefDialect
18-
MLIROpenACCMPCommon
19-
)
20-
1+
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
add_mlir_dialect_library(MLIROpenACCDialect
2+
OpenACC.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
6+
7+
DEPENDS
8+
MLIROpenACCOpsIncGen
9+
MLIROpenACCEnumsIncGen
10+
MLIROpenACCAttributesIncGen
11+
MLIROpenACCOpsInterfacesIncGen
12+
MLIROpenACCTypeInterfacesIncGen
13+
14+
LINK_LIBS PUBLIC
15+
MLIRIR
16+
MLIRLLVMDialect
17+
MLIRMemRefDialect
18+
MLIROpenACCMPCommon
19+
)
20+
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
add_mlir_dialect_library(MLIROpenACCTransforms
2+
LegalizeData.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
6+
7+
DEPENDS
8+
MLIROpenACCPassIncGen
9+
MLIROpenACCOpsIncGen
10+
MLIROpenACCEnumsIncGen
11+
MLIROpenACCAttributesIncGen
12+
MLIROpenACCOpsInterfacesIncGen
13+
MLIROpenACCTypeInterfacesIncGen
14+
15+
LINK_LIBS PUBLIC
16+
MLIROpenACCDialect
17+
MLIRIR
18+
MLIRPass
19+
MLIRTransforms
20+
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
//===- LegalizeData.cpp - -------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Dialect/OpenACC/OpenACC.h"
13+
#include "mlir/Pass/Pass.h"
14+
#include "mlir/Transforms/RegionUtils.h"
15+
16+
namespace mlir {
17+
namespace acc {
18+
#define GEN_PASS_DEF_LEGALIZEDATAINREGION
19+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
20+
} // namespace acc
21+
} // namespace mlir
22+
23+
using namespace mlir;
24+
25+
namespace {
26+
27+
template <typename Op>
28+
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
29+
llvm::SmallVector<std::pair<Value, Value>> values;
30+
for (auto operand : op.getDataClauseOperands()) {
31+
Value varPtr = acc::getVarPtr(operand.getDefiningOp());
32+
Value accPtr = acc::getAccPtr(operand.getDefiningOp());
33+
if (varPtr && accPtr) {
34+
if (hostToDevice)
35+
values.push_back({varPtr, accPtr});
36+
else
37+
values.push_back({accPtr, varPtr});
38+
}
39+
}
40+
41+
for (auto p : values)
42+
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
43+
}
44+
45+
struct LegalizeDataInRegion
46+
: public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
47+
48+
void runOnOperation() override {
49+
func::FuncOp funcOp = getOperation();
50+
bool replaceHostVsDevice = this->hostToDevice.getValue();
51+
52+
funcOp.walk([&](Operation *op) {
53+
if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
54+
return;
55+
56+
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
57+
collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
58+
} else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
59+
collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
60+
} else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
61+
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
62+
}
63+
});
64+
}
65+
};
66+
67+
} // end anonymous namespace
68+
69+
std::unique_ptr<OperationPass<func::FuncOp>>
70+
mlir::acc::createLegalizeDataInRegion() {
71+
return std::make_unique<LegalizeDataInRegion>();
72+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
2+
// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
3+
4+
func.func @test(%a: memref<10xf32>, %i : index) {
5+
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
6+
acc.parallel dataOperands(%create : memref<10xf32>) {
7+
%ci = memref.load %a[%i] : memref<10xf32>
8+
acc.yield
9+
}
10+
return
11+
}
12+
13+
// CHECK-LABEL: func.func @test
14+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
15+
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
16+
// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
17+
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
18+
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
19+
// CHECK: acc.yield
20+
// CHECK: }
21+
22+
// -----
23+
24+
func.func @test(%a: memref<10xf32>, %i : index) {
25+
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
26+
acc.serial dataOperands(%create : memref<10xf32>) {
27+
%ci = memref.load %a[%i] : memref<10xf32>
28+
acc.yield
29+
}
30+
return
31+
}
32+
33+
// CHECK-LABEL: func.func @test
34+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
35+
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
36+
// CHECK: acc.serial dataOperands(%[[CREATE]] : memref<10xf32>) {
37+
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
38+
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
39+
// CHECK: acc.yield
40+
// CHECK: }
41+
42+
// -----
43+
44+
func.func @test(%a: memref<10xf32>, %i : index) {
45+
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
46+
acc.kernels dataOperands(%create : memref<10xf32>) {
47+
%ci = memref.load %a[%i] : memref<10xf32>
48+
acc.terminator
49+
}
50+
return
51+
}
52+
53+
// CHECK-LABEL: func.func @test
54+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
55+
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
56+
// CHECK: acc.kernels dataOperands(%[[CREATE]] : memref<10xf32>) {
57+
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
58+
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
59+
// CHECK: acc.terminator
60+
// CHECK: }
61+
62+
// -----
63+
64+
func.func @test(%a: memref<10xf32>) {
65+
%lb = arith.constant 0 : index
66+
%st = arith.constant 1 : index
67+
%c10 = arith.constant 10 : index
68+
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
69+
acc.parallel dataOperands(%create : memref<10xf32>) {
70+
acc.loop (%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
71+
%ci = memref.load %a[%i] : memref<10xf32>
72+
acc.yield
73+
}
74+
acc.yield
75+
}
76+
return
77+
}
78+
79+
// CHECK: func.func @test
80+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
81+
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
82+
// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
83+
// CHECK: acc.loop (%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
84+
// DEVICE: %{{.*}} = memref.load %[[CREATE:.*]][%[[I]]] : memref<10xf32>
85+
// CHECK: acc.yield
86+
// CHECK: }
87+
// CHECK: acc.yield
88+
// CHECK: }

0 commit comments

Comments
 (0)