diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index e56b26a243a44..aef9352e70ea3 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -1123,6 +1123,8 @@ addReductionDecl(mlir::Location currentLocation, Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); mlir::Type redType = symVal.getType().cast().getEleTy(); reductionVars.push_back(symVal); @@ -1160,6 +1162,8 @@ addReductionDecl(mlir::Location currentLocation, Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); mlir::Type redType = symVal.getType().cast().getEleTy(); reductionVars.push_back(symVal); @@ -3746,6 +3750,8 @@ void Fortran::lower::genOpenMPReduction( Fortran::parser::Unwrap(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp()) + reductionVal = declOp.getBase(); mlir::Type reductionType = reductionVal.getType().cast().getEleTy(); if (!reductionType.isa()) { @@ -3789,6 +3795,9 @@ void Fortran::lower::genOpenMPReduction( ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = + reductionVal.getDefiningOp()) + reductionVal = declOp.getBase(); for (const mlir::OpOperand &reductionValUse : reductionVal.getUses()) { if (auto loadOp = mlir::dyn_cast( @@ -3844,6 +3853,13 @@ mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal, return reductionOp; } } + if (auto assign = + mlir::dyn_cast(reductionOperand.getOwner())) { + if (assign.getLhs() == *reductionVal) { + assign.erase(); + return reductionOp; + } + } } } } @@ -3899,6 +3915,11 @@ void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp, if (storeOp.getMemref() == symVal) storeOp.erase(); } + if (auto assignOp = + mlir::dyn_cast(convertReductionUse)) { + if (assignOp.getLhs() == symVal) + assignOp.erase(); + } } } } diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 new file mode 100644 index 0000000000000..97ee665442e3a --- /dev/null +++ b/flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 @@ -0,0 +1,43 @@ +! RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +!CHECK-LABEL: omp.reduction.declare +!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init { +!CHECK: ^bb0(%{{.*}}: i32): +!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32 +!CHECK: omp.yield(%[[C0_1]] : i32) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32): +!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32 +!CHECK: omp.yield(%[[RES]] : i32) +!CHECK: } + +!CHECK-LABEL: func.func @_QPsimple_int_reduction +!CHECK: %[[XREF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_int_reductionEx"} +!CHECK: %[[XDECL:.*]]:2 = hlfir.declare %[[XREF]] {uniq_name = "_QFsimple_int_reductionEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[C0_2:.*]] = arith.constant 0 : i32 +!CHECK: hlfir.assign %[[C0_2]] to %[[XDECL]]#0 : i32, !fir.ref +!CHECK: omp.parallel +!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned} +!CHECK: %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFsimple_int_reductionEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[C1_1:.*]] = arith.constant 1 : i32 +!CHECK: %[[C100:.*]] = arith.constant 100 : i32 +!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32 +!CHECK: omp.wsloop reduction(@[[RED_I32_NAME]] -> %[[XDECL]]#0 : !fir.ref) for (%[[IVAL:.*]]) : i32 = (%[[C1_1]]) to (%[[C100]]) inclusive step (%[[C1_2]]) +!CHECK: fir.store %[[IVAL]] to %[[I_PVT_DECL]]#1 : !fir.ref +!CHECK: %[[I_PVT_VAL:.*]] = fir.load %[[I_PVT_DECL]]#0 : !fir.ref +!CHECK: omp.reduction %[[I_PVT_VAL]], %[[XDECL]]#0 : i32, !fir.ref +!CHECK: omp.yield +!CHECK: omp.terminator +!CHECK: return +subroutine simple_int_reduction + integer :: x + x = 0 + !$omp parallel + !$omp do reduction(+:x) + do i=1, 100 + x = x + i + end do + !$omp end do + !$omp end parallel +end subroutine diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90 new file mode 100644 index 0000000000000..0c5d99226600b --- /dev/null +++ b/flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90 @@ -0,0 +1,36 @@ +! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s + +!CHECK: omp.reduction.declare @[[MAX_DECLARE_I:.*]] : i32 init { +!CHECK: %[[MINIMUM_VAL_I:.*]] = arith.constant -2147483648 : i32 +!CHECK: omp.yield(%[[MINIMUM_VAL_I]] : i32) +!CHECK: combiner +!CHECK: ^bb0(%[[ARG0_I:.*]]: i32, %[[ARG1_I:.*]]: i32): +!CHECK: %[[COMB_VAL_I:.*]] = arith.maxsi %[[ARG0_I]], %[[ARG1_I]] : i32 +!CHECK: omp.yield(%[[COMB_VAL_I]] : i32) + +!CHECK-LABEL: @_QPreduction_max_int +!CHECK-SAME: %[[Y_BOX:.*]]: !fir.box> +!CHECK: %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFreduction_max_intEx"} +!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFreduction_max_intEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_BOX]] {uniq_name = "_QFreduction_max_intEy"} : (!fir.box>) -> (!fir.box>, !fir.box>) +!CHECK: omp.parallel +!CHECK: omp.wsloop reduction(@[[MAX_DECLARE_I]] -> %[[X_DECL]]#0 : !fir.ref) for +!CHECK: %[[Y_I_REF:.*]] = hlfir.designate %[[Y_DECL]]#0 ({{.*}}) : (!fir.box>, i64) -> !fir.ref +!CHECK: %[[Y_I:.*]] = fir.load %[[Y_I_REF]] : !fir.ref +!CHECK: omp.reduction %[[Y_I]], %[[X_DECL]]#0 : i32, !fir.ref +!CHECK: omp.yield +!CHECK: omp.terminator + +subroutine reduction_max_int(y) + integer :: x, y(:) + x = 0 + !$omp parallel + !$omp do reduction(max:x) + do i=1, 100 + x = max(x, y(i)) + end do + !$omp end do + !$omp end parallel + print *, x +end subroutine