Skip to content

[OpenMP][Flang][MLIR] Lowering of requires directive from MLIR to LLV… #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions flang/lib/Optimizer/Transforms/LoopVersioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) {
return val;
}

/// if a value comes from a fir.rebox, follow the rebox to the original source,
/// of the value, otherwise return the value
static mlir::Value unwrapReboxOp(mlir::Value val) {
// don't support reboxes of reboxes
if (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>())
val = rebox.getBox();
return val;
}

/// normalize a value (removing fir.declare and fir.rebox) so that we can
/// more conveniently spot values which came from function arguments
static mlir::Value normaliseVal(mlir::Value val) {
return unwrapFirDeclare(unwrapReboxOp(val));
}

void LoopVersioningPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
mlir::func::FuncOp func = getOperation();
Expand All @@ -112,7 +127,7 @@ void LoopVersioningPass::runOnOperation() {
/// A structure to hold an argument, the size of the argument and dimension
/// information.
struct ArgInfo {
mlir::Value *arg;
mlir::Value arg;
size_t size;
unsigned rank;
fir::BoxDimsOp dims[CFI_MAX_RANK];
Expand All @@ -138,7 +153,7 @@ void LoopVersioningPass::runOnOperation() {
else if (auto cty = elementType.dyn_cast<fir::ComplexType>())
typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8;
if (typeSize)
argsOfInterest.push_back({&arg, typeSize, rank, {}});
argsOfInterest.push_back({arg, typeSize, rank, {}});
else
LLVM_DEBUG(llvm::dbgs() << "Type not supported\n");
}
Expand Down Expand Up @@ -166,7 +181,9 @@ void LoopVersioningPass::runOnOperation() {
return;
mlir::Value operand = op->getOperand(0);
for (auto a : argsOfInterest) {
if (*a.arg == unwrapFirDeclare(operand)) {
if (a.arg == normaliseVal(operand)) {
// use the reboxed value, not the block arg when re-creating the loop:
a.arg = operand;
// Only add if it's not already in the list.
if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) {
return it.arg == a.arg;
Expand Down Expand Up @@ -211,7 +228,7 @@ void LoopVersioningPass::runOnOperation() {
for (unsigned i = 0; i < ndims; i++) {
mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
arg.dims[i] = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
*arg.arg, dimIdx);
arg.arg, dimIdx);
}
// We only care about lowest order dimension, here.
mlir::Value elemSize =
Expand All @@ -238,11 +255,11 @@ void LoopVersioningPass::runOnOperation() {
for (auto &arg : op.argsAndDims) {
fir::SequenceType::Shape newShape;
newShape.push_back(fir::SequenceType::getUnknownExtent());
auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg->getType());
auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg.getType());
mlir::Type arrTy = fir::SequenceType::get(newShape, elementType);
mlir::Type boxArrTy = fir::BoxType::get(arrTy);
mlir::Type refArrTy = builder.getRefType(arrTy);
auto carg = builder.create<fir::ConvertOp>(loc, boxArrTy, *arg.arg);
auto carg = builder.create<fir::ConvertOp>(loc, boxArrTy, arg.arg);
auto caddr = builder.create<fir::BoxAddrOp>(loc, refArrTy, carg);
auto insPt = builder.saveInsertionPoint();
// Use caddr instead of arg.
Expand All @@ -254,8 +271,7 @@ void LoopVersioningPass::runOnOperation() {
// arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x)
// where stride is the distance between elements in the dimensions
// 0, 1 and 2 or x, y and z.
if (unwrapFirDeclare(coop->getOperand(0)) == *arg.arg &&
coop->getOperands().size() >= 2) {
if (coop->getOperand(0) == arg.arg && coop->getOperands().size() >= 2) {
builder.setInsertionPoint(coop);
mlir::Value totalIndex;
for (unsigned i = arg.rank - 1; i > 0; i--) {
Expand Down
10 changes: 6 additions & 4 deletions flang/test/Transforms/loop-versioning.fir
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
module {
func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
%decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
%rebox = fir.rebox %decl : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"}
%1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"}
%cst = arith.constant 0.000000e+00 : f64
Expand All @@ -31,7 +32,7 @@ module {
%9 = fir.convert %8 : (i32) -> i64
%c1_i64 = arith.constant 1 : i64
%10 = arith.subi %9, %c1_i64 : i64
%11 = fir.coordinate_of %decl, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
%11 = fir.coordinate_of %rebox, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
%12 = fir.load %11 : !fir.ref<f64>
%13 = arith.addf %7, %12 fastmath<contract> : f64
fir.store %13 to %1 : !fir.ref<f64>
Expand All @@ -49,12 +50,13 @@ module {
// CHECK-LABEL: func.func @sum1d(
// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xf64>> {{.*}})
// CHECK: %[[DECL:.*]] = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
// CHECK: %[[REBOX:.*]] = fir.rebox %[[DECL]]
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}}
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[REBOX]], %[[ZERO]] : {{.*}}
// CHECK: %[[SIZE:.*]] = arith.constant 8 : index
// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS]]#2, %[[SIZE]]
// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[CMP]] -> {{.*}}
// CHECK: %[[NEWARR:.*]] = fir.convert %[[ARG0]]
// CHECK: %[[NEWARR:.*]] = fir.convert %[[REBOX]]
// CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[NEWARR]] : {{.*}} -> !fir.ref<!fir.array<?xf64>>
// CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}}
// CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %{{.*}} : (!fir.ref<!fir.array<?xf64>>, index) -> !fir.ref<f64>
Expand All @@ -64,7 +66,7 @@ module {
// CHECK fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1
// CHECK: } else {
// CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}}
// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[DECL]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[REBOX]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
// CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref<f64>
// CHECK: fir.result %{{.*}}, %{{.*}}
// CHECK: }
Expand Down