Skip to content

Allow do concurrent inside cuf kernel directive #127693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 20, 2025
Merged

Conversation

wangzpgi
Copy link
Contributor

Allow do concurrent inside cuf kernel directive to avoid the following Lowering error:

void {anonymous}::FirConverter::genFIR(const Fortran::parser::CUFKernelDoConstruct&): Assertion `bounds && "Expected bounds on the loop construct"' failed.

@wangzpgi wangzpgi self-assigned this Feb 18, 2025
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 18, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 18, 2025

@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-fir-hlfir

Author: Zhen Wang (wangzpgi)

Changes

Allow do concurrent inside cuf kernel directive to avoid the following Lowering error:

void {anonymous}::FirConverter::genFIR(const Fortran::parser::CUFKernelDoConstruct&): Assertion `bounds && "Expected bounds on the loop construct"' failed.

Full diff: https://github.com/llvm/llvm-project/pull/127693.diff

2 Files Affected:

  • (modified) flang/lib/Lower/Bridge.cpp (+125-40)
  • (added) flang/test/Lower/CUDA/cuda-doconc.cuf (+20)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 36e58e456dea3..61dd9f0797fc9 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3074,50 +3074,135 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     llvm::SmallVector<mlir::Value> ivValues;
     Fortran::lower::pft::Evaluation *loopEval =
         &getEval().getFirstNestedEvaluation();
-    for (unsigned i = 0; i < nestedLoops; ++i) {
-      const Fortran::parser::LoopControl *loopControl;
-      mlir::Location crtLoc = loc;
-      if (i == 0) {
-        loopControl = &*outerDoConstruct->GetLoopControl();
-        crtLoc =
-            genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
-      } else {
-        auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
-        assert(doCons && "expect do construct");
-        loopControl = &*doCons->GetLoopControl();
-        crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
+    bool isDoConcurrent = outerDoConstruct->IsDoConcurrent();
+    if (isDoConcurrent) {
+      // Handle DO CONCURRENT
+      locs.push_back(
+          genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct)));
+      const Fortran::parser::LoopControl *loopControl =
+          &*outerDoConstruct->GetLoopControl();
+      const auto &concurrent =
+          std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u);
+
+      if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t)
+               .empty())
+        TODO(loc, "DO CONCURRENT with locality spec");
+
+      const auto &concurrentHeader =
+          std::get<Fortran::parser::ConcurrentHeader>(concurrent.t);
+      const auto &controls =
+          std::get<std::list<Fortran::parser::ConcurrentControl>>(
+              concurrentHeader.t);
+
+      for (const auto &control : controls) {
+        auto lb = fir::getBase(genExprValue(
+            *Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx));
+        auto ub = fir::getBase(genExprValue(
+            *Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx));
+        mlir::Value step;
+
+        if (const auto &expr =
+                std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
+                    control.t)) {
+          step = fir::getBase(
+              genExprValue(*Fortran::semantics::GetExpr(*expr), stmtCtx));
+        } else {
+          step = builder->create<mlir::arith::ConstantIndexOp>(
+              loc, 1); // Use index type directly
+        }
+
+        // Ensure lb, ub, and step are of index type using fir.convert
+        auto indexType = builder->getIndexType();
+        if (lb.getType() != indexType) {
+          lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
+        }
+        if (ub.getType() != indexType) {
+          ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
+        }
+        if (step.getType() != indexType) {
+          step = builder->create<fir::ConvertOp>(loc, indexType, step);
+        }
+
+        lbs.push_back(lb);
+        ubs.push_back(ub);
+        steps.push_back(step);
+
+        const auto &name = std::get<Fortran::parser::Name>(control.t);
+
+        // Handle induction variable
+        mlir::Value ivValue = getSymbolAddress(*name.symbol);
+        std::size_t ivTypeSize = name.symbol->size();
+        if (ivTypeSize == 0)
+          llvm::report_fatal_error("unexpected induction variable size");
+        mlir::Type ivTy = builder->getIntegerType(ivTypeSize * 8);
+
+        if (!ivValue) {
+          // DO CONCURRENT induction variables are not mapped yet since they are
+          // local to the DO CONCURRENT scope.
+          mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+          builder->setInsertionPointToStart(builder->getAllocaBlock());
+          ivValue = builder->createTemporaryAlloc(
+              loc, ivTy, toStringRef(name.symbol->name()));
+          builder->restoreInsertionPoint(insPt);
+        }
+
+        // Create the hlfir.declare operation using the symbol's name
+        auto declareOp = builder->create<hlfir::DeclareOp>(
+            loc, ivValue, toStringRef(name.symbol->name()));
+        ivValue = declareOp.getResult(0);
+
+        // Bind the symbol to the declared variable
+        bindSymbol(*name.symbol, ivValue);
+        ivValues.push_back(ivValue);
+        ivTypes.push_back(ivTy);
+        ivLocs.push_back(loc);
       }
+    } else {
+      for (unsigned i = 0; i < nestedLoops; ++i) {
+        const Fortran::parser::LoopControl *loopControl;
+        mlir::Location crtLoc = loc;
+        if (i == 0) {
+          loopControl = &*outerDoConstruct->GetLoopControl();
+          crtLoc = genLocation(
+              Fortran::parser::FindSourceLocation(outerDoConstruct));
+        } else {
+          auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
+          assert(doCons && "expect do construct");
+          loopControl = &*doCons->GetLoopControl();
+          crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
+        }
+
+        locs.push_back(crtLoc);
 
-      locs.push_back(crtLoc);
-
-      const Fortran::parser::LoopControl::Bounds *bounds =
-          std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
-      assert(bounds && "Expected bounds on the loop construct");
-
-      Fortran::semantics::Symbol &ivSym =
-          bounds->name.thing.symbol->GetUltimate();
-      ivValues.push_back(getSymbolAddress(ivSym));
-
-      lbs.push_back(builder->createConvert(
-          crtLoc, idxTy,
-          fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
-                                    stmtCtx))));
-      ubs.push_back(builder->createConvert(
-          crtLoc, idxTy,
-          fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
-                                    stmtCtx))));
-      if (bounds->step)
-        steps.push_back(builder->createConvert(
+        const Fortran::parser::LoopControl::Bounds *bounds =
+            std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+        assert(bounds && "Expected bounds on the loop construct");
+
+        Fortran::semantics::Symbol &ivSym =
+            bounds->name.thing.symbol->GetUltimate();
+        ivValues.push_back(getSymbolAddress(ivSym));
+
+        lbs.push_back(builder->createConvert(
             crtLoc, idxTy,
             fir::getBase(genExprValue(
-                *Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
-      else // If `step` is not present, assume it is `1`.
-        steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
-
-      ivTypes.push_back(idxTy);
-      ivLocs.push_back(crtLoc);
-      if (i < nestedLoops - 1)
-        loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
+                *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
+        ubs.push_back(builder->createConvert(
+            crtLoc, idxTy,
+            fir::getBase(genExprValue(
+                *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
+        if (bounds->step)
+          steps.push_back(builder->createConvert(
+              crtLoc, idxTy,
+              fir::getBase(genExprValue(
+                  *Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
+        else // If `step` is not present, assume it is `1`.
+          steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
+
+        ivTypes.push_back(idxTy);
+        ivLocs.push_back(crtLoc);
+        if (i < nestedLoops - 1)
+          loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
+      }
     }
 
     auto op = builder->create<cuf::KernelOp>(
diff --git a/flang/test/Lower/CUDA/cuda-doconc.cuf b/flang/test/Lower/CUDA/cuda-doconc.cuf
new file mode 100644
index 0000000000000..e11688f4fe960
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-doconc.cuf
@@ -0,0 +1,20 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Check if do concurrent works inside cuf kernel directive
+
+program main
+  integer :: i, n
+  integer, managed :: a(3)
+  a(:) = -1
+  n = 3
+  n = n - 1
+  !$cuf kernel do
+  do concurrent(i=1:n)
+    a(i) = 1
+  end do
+end
+
+! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "main"} {
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.kernel<<<*, *>>>
+! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<i32>

@clementval
Copy link
Contributor

clementval commented Feb 19, 2025 via email

@wangzpgi
Copy link
Contributor Author

You are right, I will add a multi range example.

@wangzpgi wangzpgi requested a review from clementval February 19, 2025 04:47
@wangzpgi wangzpgi requested a review from clementval February 19, 2025 17:07
! CHECK: func.func @_QPdoconc2() {
! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: cuf.kernel<<<*, *>>> (%arg0 : i32, %arg1 : i32) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really the semantic of !$cuf kernel do? For me it means !$cuf kernel do(1) so only the first range is part of the cuf kernel operation and the rest should be nested inside. Let me know if the semantic is different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really the semantic of !$cuf kernel do? For me it means !$cuf kernel do(1) so only the first range is part of the cuf kernel operation and the rest should be nested inside. Let me know if the semantic is different.

Yes, it's the same semantic with do and do(1), we have this multi range with cuf.kernel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then only one range should be on the cuf.kernel. The rest of the ranges should be nested inside the op.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then only one range should be on the cuf.kernel. The rest of the ranges should be nested inside the op.

After discussion, we decide to keep all the range info in cuf.kernel. Scheduling of loops will be handled later based on loop number in cuf.kernel.

Copy link

github-actions bot commented Feb 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@wangzpgi wangzpgi requested a review from clementval February 20, 2025 19:39
! CHECK: func.func @_QPdoconc2() {
! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: cuf.kernel<<<*, *>>> (%arg0 : i32, %arg1 : i32) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not setting the n attribute if it is one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not setting the n attribute if it is one?

Correct. If it's omitted, it's one by default.

@@ -133,6 +133,10 @@ program main
!$cuf kernel do <<< 1, 2 >>>
do concurrent (j=1:10)
end do
!ERROR: !$CUF KERNEL DO (2) must be followed by a DO construct with tightly nested outer levels of counted DO loops
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we need to adapt the message to add do concurrent in it.

!$CUF KERNEL DO (2) must be followed by a DO construct with tightly nested outer levels of counted DO loops or a DO CONCURRENT construct with at least 2 concurrent headers

@clementval
Copy link
Contributor

Looks mostly ok. Just a comment on the error in sema and a question.

@wangzpgi wangzpgi requested a review from clementval February 20, 2025 21:32
Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Just a couple of auto to be spelled.

wangzpgi and others added 2 commits February 20, 2025 13:47
Co-authored-by: Valentin Clement (バレンタイン クレメン) <[email protected]>
Co-authored-by: Valentin Clement (バレンタイン クレメン) <[email protected]>
@wangzpgi wangzpgi merged commit a67566b into llvm:main Feb 20, 2025
9 of 10 checks passed
@wangzpgi wangzpgi deleted the doconc branch March 5, 2025 19:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants