Skip to content

[mlir][openacc] Add legalize data pass for compute operation #80351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Support/InitFIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
Expand Down Expand Up @@ -74,6 +75,7 @@ inline void loadDialects(mlir::MLIRContext &context) {
/// Register the standard passes we use. This comes from registerAllPasses(),
/// but is a smaller set since we aren't using many of the passes found there.
inline void registerMLIRPassesForFortranTools() {
mlir::acc::registerOpenACCPasses();
mlir::registerCanonicalizerPass();
mlir::registerCSEPass();
mlir::affine::registerAffineLoopFusionPass();
Expand Down
24 changes: 24 additions & 0 deletions flang/test/Fir/OpenACC/legalize-data.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s

// The assignment inside acc.parallel targets the host reference %0#0. After
// the pass it must target the result of the matching acc.copyin, while the
// acc.copyout that follows the region keeps the host reference as its varPtr.
func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
  %1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
  acc.parallel dataOperands(%1 : !fir.ref<i32>) {
    %c0_i32 = arith.constant 0 : i32
    hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
    acc.yield
  }
  acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
  return
}

// CHECK-LABEL: func.func @_QPsub1
// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
// CHECK: acc.parallel dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: hlfir.assign %[[C0]] to %[[COPYIN]] : i32, !fir.ref<i32>
// CHECK: acc.yield
// CHECK: }
// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(Transforms)

set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenACC/ACC.td)
mlir_tablegen(AccCommon.td --gen-directive-decl --directives-dialect=OpenACC)
add_public_tablegen_target(acc_common_td)
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Run mlir-tblgen over Passes.td to produce Passes.h.inc, the generated include
# consumed by the OpenACC Transforms Passes.h header.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenACC)
# Public tablegen target; libraries that include Passes.h.inc list it under
# DEPENDS.
add_public_tablegen_target(MLIROpenACCPassIncGen)

# Emit the pass documentation from the same Passes.td.
add_mlir_doc(Passes OpenACCPasses ./ -gen-pass-doc)
40 changes: 40 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===- Passes.h - OpenACC Passes Construction and Registration ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H

#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
#include "mlir/Pass/Pass.h"

// Pull in the tablegen-generated pass declarations.
// NOTE(review): this include sits at global scope, whereas LegalizeData.cpp
// includes the GEN_PASS_DEF section inside namespace mlir::acc -- confirm the
// generated declarations are really meant to live outside mlir::acc.
#define GEN_PASS_DECL
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"

namespace mlir {

// Forward declaration: only the name of the pass anchor op is needed here.
namespace func {
class FuncOp;
} // namespace func

namespace acc {

/// Create the `openacc-legalize-data` pass, which runs on func::FuncOp and
/// replaces, inside OpenACC compute regions, uses of a data clause's varPtr
/// with its accPtr (or the reverse, depending on the pass's `host-to-device`
/// option).
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();

/// Generate the code for registering the OpenACC passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"

} // namespace acc
} // namespace mlir

#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===-- Passes.td - OpenACC pass definition file -----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

// Function-level pass that rewrites pointer uses inside OpenACC compute
// regions; the direction of the rewrite is selected by `host-to-device`.
def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
  let summary = "Legalize the data in the compute region";
  let description = [{
    This pass replaces uses of varPtr in the compute region with their accPtr
    gathered from the data clause operands. When the `host-to-device` option
    is set to false, the replacement goes the other way: uses of accPtr in the
    compute region are replaced with the corresponding varPtr.
  }];
  let options = [
    Option<"hostToDevice", "host-to-device", "bool", "true",
           "Replace varPtr uses with accPtr if true. Replace accPtr uses with "
           "varPtr if false">
  ];
  let constructor = "::mlir::acc::createLegalizeDataInRegion()";
}

#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
Expand Down Expand Up @@ -64,6 +65,7 @@ inline void registerAllPasses() {
registerConversionPasses();

// Dialect passes
acc::registerOpenACCPasses();
affine::registerAffinePasses();
amdgpu::registerAMDGPUPasses();
registerAsyncPasses();
Expand Down
22 changes: 2 additions & 20 deletions mlir/lib/Dialect/OpenACC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,2 @@
add_mlir_dialect_library(MLIROpenACCDialect
IR/OpenACC.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC

DEPENDS
MLIROpenACCOpsIncGen
MLIROpenACCEnumsIncGen
MLIROpenACCAttributesIncGen
MLIROpenACCOpsInterfacesIncGen
MLIROpenACCTypeInterfacesIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
MLIRMemRefDialect
MLIROpenACCMPCommon
)

add_subdirectory(IR)
add_subdirectory(Transforms)
20 changes: 20 additions & 0 deletions mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# OpenACC dialect library, built from OpenACC.cpp in this directory.
add_mlir_dialect_library(MLIROpenACCDialect
OpenACC.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC

# Tablegen targets whose generated files must exist before compiling.
DEPENDS
MLIROpenACCOpsIncGen
MLIROpenACCEnumsIncGen
MLIROpenACCAttributesIncGen
MLIROpenACCOpsInterfacesIncGen
MLIROpenACCTypeInterfacesIncGen

# Libraries linked publicly, i.e. also propagated to users of this library.
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
MLIRMemRefDialect
MLIROpenACCMPCommon
)

17 changes: 17 additions & 0 deletions mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Library holding the OpenACC dialect transformation passes.
#
# LegalizeData.cpp uses func::FuncOp, the pass infrastructure and
# replaceAllUsesInRegionWith (RegionUtils) in addition to the OpenACC dialect,
# so those libraries are linked explicitly rather than relying on whatever
# MLIROpenACCDialect happens to pull in transitively.
add_mlir_dialect_library(MLIROpenACCTransforms
  LegalizeData.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC

  DEPENDS
  MLIROpenACCPassIncGen
  MLIROpenACCOpsIncGen
  MLIROpenACCEnumsIncGen
  MLIROpenACCAttributesIncGen
  MLIROpenACCOpsInterfacesIncGen
  MLIROpenACCTypeInterfacesIncGen

  LINK_LIBS PUBLIC
  MLIRFuncDialect
  MLIRIR
  MLIROpenACCDialect
  MLIRPass
  MLIRTransformUtils
  )
72 changes: 72 additions & 0 deletions mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//===- LegalizeData.cpp - -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenACC/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir {
namespace acc {
#define GEN_PASS_DEF_LEGALIZEDATAINREGION
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
} // namespace acc
} // namespace mlir

using namespace mlir;

namespace {

template <typename Op>
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
llvm::SmallVector<std::pair<Value, Value>> values;
for (auto operand : op.getDataClauseOperands()) {
Value varPtr = acc::getVarPtr(operand.getDefiningOp());
Value accPtr = acc::getAccPtr(operand.getDefiningOp());
if (varPtr && accPtr) {
if (hostToDevice)
values.push_back({varPtr, accPtr});
else
values.push_back({accPtr, varPtr});
}
}

for (auto p : values)
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
}

/// Pass implementation: walks the function and, for every OpenACC compute
/// construct (parallel, serial or kernels), replaces the data clause pointer
/// uses inside the construct's region in the direction chosen by the
/// `hostToDevice` option.
struct LegalizeDataInRegion
    : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {

  void runOnOperation() override {
    func::FuncOp function = getOperation();
    const bool toDevice = this->hostToDevice.getValue();

    function.walk([&](Operation *nested) {
      // Operations other than compute constructs are left untouched.
      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*nested))
        return;

      // Dispatch on the concrete construct so the templated helper is
      // instantiated with the matching op class.
      if (auto parallel = dyn_cast<acc::ParallelOp>(*nested)) {
        collectAndReplaceInRegion(parallel, toDevice);
        return;
      }
      if (auto serial = dyn_cast<acc::SerialOp>(*nested)) {
        collectAndReplaceInRegion(serial, toDevice);
        return;
      }
      if (auto kernels = dyn_cast<acc::KernelsOp>(*nested))
        collectAndReplaceInRegion(kernels, toDevice);
    });
  }
};

} // end anonymous namespace

/// Factory returning a new instance of the LegalizeDataInRegion pass.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::acc::createLegalizeDataInRegion() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<LegalizeDataInRegion>();
  return pass;
}
88 changes: 88 additions & 0 deletions mlir/test/Dialect/OpenACC/legalize-data.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST

// acc.parallel case: the load in the region reads %a, the varPtr of the
// acc.create feeding the construct. In the default mode the load is expected
// to read the acc.create result instead; with host-to-device=false it is
// expected to keep reading the function argument.
func.func @test(%a: memref<10xf32>, %i : index) {
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
acc.parallel dataOperands(%create : memref<10xf32>) {
%ci = memref.load %a[%i] : memref<10xf32>
acc.yield
}
return
}

// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
// CHECK: acc.yield
// CHECK: }

// -----

// acc.serial case: same scenario as the acc.parallel test, with the load
// placed in an acc.serial region.
func.func @test(%a: memref<10xf32>, %i : index) {
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
acc.serial dataOperands(%create : memref<10xf32>) {
%ci = memref.load %a[%i] : memref<10xf32>
acc.yield
}
return
}

// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
// CHECK: acc.serial dataOperands(%[[CREATE]] : memref<10xf32>) {
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
// CHECK: acc.yield
// CHECK: }

// -----

// acc.kernels case: same scenario again; note that this region ends with
// acc.terminator rather than acc.yield.
func.func @test(%a: memref<10xf32>, %i : index) {
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
acc.kernels dataOperands(%create : memref<10xf32>) {
%ci = memref.load %a[%i] : memref<10xf32>
acc.terminator
}
return
}

// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
// CHECK: acc.kernels dataOperands(%[[CREATE]] : memref<10xf32>) {
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
// CHECK: acc.terminator
// CHECK: }

// -----

// Nested case: the load sits in an acc.loop inside acc.parallel, so the
// replacement has to reach uses in regions nested within the compute region.
func.func @test(%a: memref<10xf32>) {
  %lb = arith.constant 0 : index
  %st = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
  acc.parallel dataOperands(%create : memref<10xf32>) {
    acc.loop (%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
      %ci = memref.load %a[%i] : memref<10xf32>
      acc.yield
    }
    acc.yield
  }
  return
}

// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
// CHECK: acc.loop (%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
// CHECK: acc.yield
// CHECK: }
// CHECK: acc.yield
// CHECK: }