diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 512c7a349ae21..14e99757925ac 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -104,7 +104,7 @@ struct IncrementLoopInfo { bool hasLocalitySpecs() const { return !localSymList.empty() || !localInitSymList.empty() || - !sharedSymList.empty(); + !reduceSymList.empty() || !sharedSymList.empty(); } // Data members common to both structured and unstructured loops. @@ -116,6 +116,9 @@ struct IncrementLoopInfo { bool isUnordered; // do concurrent, forall llvm::SmallVector localSymList; llvm::SmallVector localInitSymList; + llvm::SmallVector< + std::pair> + reduceSymList; llvm::SmallVector sharedSymList; mlir::Value loopVariable = nullptr; @@ -1741,6 +1744,35 @@ class FirConverter : public Fortran::lower::AbstractConverter { builder->create(loc); } + fir::ReduceOperationEnum + getReduceOperationEnum(const Fortran::parser::ReductionOperator &rOpr) { + switch (rOpr.v) { + case Fortran::parser::ReductionOperator::Operator::Plus: + return fir::ReduceOperationEnum::Add; + case Fortran::parser::ReductionOperator::Operator::Multiply: + return fir::ReduceOperationEnum::Multiply; + case Fortran::parser::ReductionOperator::Operator::And: + return fir::ReduceOperationEnum::AND; + case Fortran::parser::ReductionOperator::Operator::Or: + return fir::ReduceOperationEnum::OR; + case Fortran::parser::ReductionOperator::Operator::Eqv: + return fir::ReduceOperationEnum::EQV; + case Fortran::parser::ReductionOperator::Operator::Neqv: + return fir::ReduceOperationEnum::NEQV; + case Fortran::parser::ReductionOperator::Operator::Max: + return fir::ReduceOperationEnum::MAX; + case Fortran::parser::ReductionOperator::Operator::Min: + return fir::ReduceOperationEnum::MIN; + case Fortran::parser::ReductionOperator::Operator::Iand: + return fir::ReduceOperationEnum::IAND; + case Fortran::parser::ReductionOperator::Operator::Ior: + return fir::ReduceOperationEnum::IOR; + case Fortran::parser::ReductionOperator::Operator::Ieor: + return fir::ReduceOperationEnum::EIOR; + } + llvm_unreachable("illegal reduction operator"); + } + /// Collect DO CONCURRENT or FORALL loop control information. IncrementLoopNestInfo getConcurrentControl( const Fortran::parser::ConcurrentHeader &header, @@ -1763,6 +1795,16 @@ class FirConverter : public Fortran::lower::AbstractConverter { std::get_if(&x.u)) for (const Fortran::parser::Name &x : localInitList->v) info.localInitSymList.push_back(x.symbol); + if (const auto *reduceList = + std::get_if(&x.u)) { + fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum( + std::get(reduceList->t)); + for (const Fortran::parser::Name &x : + std::get>(reduceList->t)) { + info.reduceSymList.push_back( + std::make_pair(reduce_operation, x.symbol)); + } + } if (const auto *sharedList = std::get_if(&x.u)) for (const Fortran::parser::Name &x : sharedList->v) @@ -1955,9 +1997,23 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::Type loopVarType = info.getLoopVariableType(); mlir::Value loopValue; if (info.isUnordered) { + llvm::SmallVector reduceOperands; + llvm::SmallVector reduceAttrs; + // Create DO CONCURRENT reduce operands and attributes + for (const auto reduceSym : info.reduceSymList) { + const fir::ReduceOperationEnum reduce_operation = reduceSym.first; + const Fortran::semantics::Symbol *sym = reduceSym.second; + fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr); + reduceOperands.push_back(fir::getBase(exv)); + auto reduce_attr = + fir::ReduceAttr::get(builder->getContext(), reduce_operation); + reduceAttrs.push_back(reduce_attr); + } // The loop variable value is explicitly updated. info.doLoop = builder->create( - loc, lowerValue, upperValue, stepValue, /*unordered=*/true); + loc, lowerValue, upperValue, stepValue, /*unordered=*/true, + /*finalCountValue=*/false, /*iterArgs=*/std::nullopt, + llvm::ArrayRef(reduceOperands), reduceAttrs); builder->setInsertionPointToStart(info.doLoop.getBody()); loopValue = builder->createConvert(loc, loopVarType, info.doLoop.getInductionVar()); diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90 new file mode 100644 index 0000000000000..2e62ee480ec8a --- /dev/null +++ b/flang/test/Lower/loops3.f90 @@ -0,0 +1,23 @@ +! Test do concurrent reduction +! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s + +! CHECK-LABEL: loop_test +subroutine loop_test + integer(4) :: i, j, k, tmp, sum = 0 + real :: m + + i = 100 + j = 200 + k = 300 + + ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"} + ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref + ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered { + ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered { + ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered reduce(#fir.reduce_attr -> %[[VAL_1:.*]] : !fir.ref, #fir.reduce_attr -> %[[VAL_0:.*]] : !fir.ref) { + do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m) + tmp = i + j + k + sum = tmp + sum + m = max(m, sum) + enddo +end subroutine loop_test