-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause #94718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
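@llvm/pr-subscribers-flang-fir-hlfir Author: None (khaki3) Changes: Derived from #92480. This PR updates the lowering process of DO CONCURRENT to support the F'2023 REDUCE clause. The structure `IncrementLoopInfo` is extended to have both reduction operations and symbols in `reduceSymList`. The function `getConcurrentControl` constructs `reduceSymList` for the innermost loop. Finally, `genFIRIncrementLoopBegin` builds `fir.do_loop` with reduction operands. Full diff: https://github.com/llvm/llvm-project/pull/94718.diff 2 Files Affected: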
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 512c7a349ae21..d0a0a36500f61 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -104,7 +104,7 @@ struct IncrementLoopInfo {
bool hasLocalitySpecs() const {
return !localSymList.empty() || !localInitSymList.empty() ||
- !sharedSymList.empty();
+ !reduceSymList.empty() || !sharedSymList.empty();
}
// Data members common to both structured and unstructured loops.
@@ -116,6 +116,9 @@ struct IncrementLoopInfo {
bool isUnordered; // do concurrent, forall
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
+ llvm::SmallVector<
+ std::pair<fir::ReduceOperationEnum, const Fortran::semantics::Symbol *>>
+ reduceSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
@@ -1741,6 +1744,36 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->create<fir::UnreachableOp>(loc);
}
+ fir::ReduceOperationEnum
+ getReduceOperationEnum(const Fortran::parser::ReductionOperator &rOpr) {
+ switch (rOpr.v) {
+ case Fortran::parser::ReductionOperator::Operator::Plus:
+ return fir::ReduceOperationEnum::Add;
+ case Fortran::parser::ReductionOperator::Operator::Multiply:
+ return fir::ReduceOperationEnum::Multiply;
+ case Fortran::parser::ReductionOperator::Operator::And:
+ return fir::ReduceOperationEnum::AND;
+ case Fortran::parser::ReductionOperator::Operator::Or:
+ return fir::ReduceOperationEnum::OR;
+ case Fortran::parser::ReductionOperator::Operator::Eqv:
+ return fir::ReduceOperationEnum::EQV;
+ case Fortran::parser::ReductionOperator::Operator::Neqv:
+ return fir::ReduceOperationEnum::NEQV;
+ case Fortran::parser::ReductionOperator::Operator::Max:
+ return fir::ReduceOperationEnum::MAX;
+ case Fortran::parser::ReductionOperator::Operator::Min:
+ return fir::ReduceOperationEnum::MIN;
+ case Fortran::parser::ReductionOperator::Operator::Iand:
+ return fir::ReduceOperationEnum::IAND;
+ case Fortran::parser::ReductionOperator::Operator::Ior:
+ return fir::ReduceOperationEnum::IOR;
+ case Fortran::parser::ReductionOperator::Operator::Ieor:
+ return fir::ReduceOperationEnum::EIOR;
+ }
+ fir::emitFatalError(toLocation(), "illegal reduction operator");
+ return fir::ReduceOperationEnum::Add;
+ }
+
/// Collect DO CONCURRENT or FORALL loop control information.
IncrementLoopNestInfo getConcurrentControl(
const Fortran::parser::ConcurrentHeader &header,
@@ -1763,6 +1796,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
for (const Fortran::parser::Name &x : localInitList->v)
info.localInitSymList.push_back(x.symbol);
+ if (const auto *reduceList =
+ std::get_if<Fortran::parser::LocalitySpec::Reduce>(&x.u)) {
+ fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum(
+ std::get<Fortran::parser::ReductionOperator>(reduceList->t));
+ for (const Fortran::parser::Name &x :
+ std::get<std::list<Fortran::parser::Name>>(reduceList->t)) {
+ info.reduceSymList.push_back(
+ std::make_pair(reduce_operation, x.symbol));
+ }
+ }
if (const auto *sharedList =
std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
for (const Fortran::parser::Name &x : sharedList->v)
@@ -1955,9 +1998,23 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::Type loopVarType = info.getLoopVariableType();
mlir::Value loopValue;
if (info.isUnordered) {
+ llvm::SmallVector<mlir::Value> reduceOperands;
+ llvm::SmallVector<mlir::Attribute> reduceAttrs;
+ // Create DO CONCURRENT reduce operations and attributes
+ for (const auto reduceSym : info.reduceSymList) {
+ const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
+ const Fortran::semantics::Symbol *sym = reduceSym.second;
+ fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
+ reduceOperands.push_back(fir::getBase(exv));
+ auto reduce_attr =
+ fir::ReduceAttr::get(builder->getContext(), reduce_operation);
+ reduceAttrs.push_back(reduce_attr);
+ }
// The loop variable value is explicitly updated.
info.doLoop = builder->create<fir::DoLoopOp>(
- loc, lowerValue, upperValue, stepValue, /*unordered=*/true);
+ loc, lowerValue, upperValue, stepValue, /*unordered=*/true,
+ /*finalCountValue=*/false, /*iterArgs=*/std::nullopt,
+ llvm::ArrayRef<mlir::Value>(reduceOperands), reduceAttrs);
builder->setInsertionPointToStart(info.doLoop.getBody());
loopValue = builder->createConvert(loc, loopVarType,
info.doLoop.getInductionVar());
diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90
new file mode 100644
index 0000000000000..dd24e26d72c31
--- /dev/null
+++ b/flang/test/Lower/loops3.f90
@@ -0,0 +1,23 @@
+! Test do concurrent reduction
+! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s
+
+! CHECK-LABEL: loop_test
+subroutine loop_test
+ integer(4) :: i, j, k, tmp, sum = 0
+ real :: m
+
+ i = 100
+ j = 200
+ k = 300
+
+ ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"
+ ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref<i32>
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
+ ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered reduce(#fir.reduce_attr<add> -> %[[VAL_1:.*]] : !fir.ref<i32>, #fir.reduce_attr<max> -> %[[VAL_0:.*]] : !fir.ref<f32>) {
+ do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m)
+ tmp = i + j + k
+ sum = tmp + sum
+ m = max(m, sum)
+ enddo
+end subroutine loop_test
|
As DO CONCURRENT is not necessarily concurrent, it is sufficient to put a reduce clause only on the innermost loop for propagating reduction semantics. The scope of reduction is analyzable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…lvm#94718) Derived from llvm#92480. This PR updates the lowering process of DO CONCURRENT to support F'2023 REDUCE clause. The structure `IncrementLoopInfo` is extended to have both reduction operations and symbols in `reduceSymList`. The function `getConcurrentControl` constructs `reduceSymList` for the innermost loop. Finally, `genFIRIncrementLoopBegin` builds `fir.do_loop` with reduction operands.
Derived from #92480. This PR updates the lowering process of DO CONCURRENT to support the F'2023 REDUCE clause. The structure `IncrementLoopInfo` is extended to have both reduction operations and symbols in `reduceSymList`. The function `getConcurrentControl` constructs `reduceSymList` for the innermost loop. Finally, `genFIRIncrementLoopBegin` builds `fir.do_loop` with reduction operands.