Skip to content

Allow do concurrent inside cuf kernel directive #127693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 117 additions & 40 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3074,50 +3074,127 @@ class FirConverter : public Fortran::lower::AbstractConverter {
llvm::SmallVector<mlir::Value> ivValues;
Fortran::lower::pft::Evaluation *loopEval =
&getEval().getFirstNestedEvaluation();
for (unsigned i = 0; i < nestedLoops; ++i) {
const Fortran::parser::LoopControl *loopControl;
mlir::Location crtLoc = loc;
if (i == 0) {
loopControl = &*outerDoConstruct->GetLoopControl();
crtLoc =
genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
} else {
auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
assert(doCons && "expect do construct");
loopControl = &*doCons->GetLoopControl();
crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
if (outerDoConstruct->IsDoConcurrent()) {
// Handle DO CONCURRENT
locs.push_back(
genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct)));
const Fortran::parser::LoopControl *loopControl =
&*outerDoConstruct->GetLoopControl();
const auto &concurrent =
std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u);

if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t)
.empty())
TODO(loc, "DO CONCURRENT with locality spec");

const auto &concurrentHeader =
std::get<Fortran::parser::ConcurrentHeader>(concurrent.t);
const auto &controls =
std::get<std::list<Fortran::parser::ConcurrentControl>>(
concurrentHeader.t);

for (const auto &control : controls) {
mlir::Value lb = fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx));
mlir::Value ub = fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx));
mlir::Value step;

if (const auto &expr =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
control.t))
step = fir::getBase(
genExprValue(*Fortran::semantics::GetExpr(*expr), stmtCtx));
else
step = builder->create<mlir::arith::ConstantIndexOp>(
loc, 1); // Use index type directly

// Ensure lb, ub, and step are of index type using fir.convert
mlir::Type indexType = builder->getIndexType();
lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
step = builder->create<fir::ConvertOp>(loc, indexType, step);

lbs.push_back(lb);
ubs.push_back(ub);
steps.push_back(step);

const auto &name = std::get<Fortran::parser::Name>(control.t);

// Handle induction variable
mlir::Value ivValue = getSymbolAddress(*name.symbol);
std::size_t ivTypeSize = name.symbol->size();
if (ivTypeSize == 0)
llvm::report_fatal_error("unexpected induction variable size");
mlir::Type ivTy = builder->getIntegerType(ivTypeSize * 8);

if (!ivValue) {
// DO CONCURRENT induction variables are not mapped yet since they are
// local to the DO CONCURRENT scope.
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
builder->setInsertionPointToStart(builder->getAllocaBlock());
ivValue = builder->createTemporaryAlloc(
loc, ivTy, toStringRef(name.symbol->name()));
builder->restoreInsertionPoint(insPt);
}

// Create the hlfir.declare operation using the symbol's name
auto declareOp = builder->create<hlfir::DeclareOp>(
loc, ivValue, toStringRef(name.symbol->name()));
ivValue = declareOp.getResult(0);

// Bind the symbol to the declared variable
bindSymbol(*name.symbol, ivValue);
ivValues.push_back(ivValue);
ivTypes.push_back(ivTy);
ivLocs.push_back(loc);
}
} else {
for (unsigned i = 0; i < nestedLoops; ++i) {
const Fortran::parser::LoopControl *loopControl;
mlir::Location crtLoc = loc;
if (i == 0) {
loopControl = &*outerDoConstruct->GetLoopControl();
crtLoc = genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct));
} else {
auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
assert(doCons && "expect do construct");
loopControl = &*doCons->GetLoopControl();
crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
}

locs.push_back(crtLoc);

const Fortran::parser::LoopControl::Bounds *bounds =
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds on the loop construct");

Fortran::semantics::Symbol &ivSym =
bounds->name.thing.symbol->GetUltimate();
ivValues.push_back(getSymbolAddress(ivSym));

lbs.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
stmtCtx))));
ubs.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
stmtCtx))));
if (bounds->step)
steps.push_back(builder->createConvert(
locs.push_back(crtLoc);

const Fortran::parser::LoopControl::Bounds *bounds =
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds on the loop construct");

Fortran::semantics::Symbol &ivSym =
bounds->name.thing.symbol->GetUltimate();
ivValues.push_back(getSymbolAddress(ivSym));

lbs.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
else // If `step` is not present, assume it is `1`.
steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));

ivTypes.push_back(idxTy);
ivLocs.push_back(crtLoc);
if (i < nestedLoops - 1)
loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
ubs.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
if (bounds->step)
steps.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
else // If `step` is not present, assume it is `1`.
steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));

ivTypes.push_back(idxTy);
ivLocs.push_back(crtLoc);
if (i < nestedLoops - 1)
loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
}
}

auto op = builder->create<cuf::KernelOp>(
Expand Down
26 changes: 23 additions & 3 deletions flang/lib/Semantics/check-cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,21 @@ static int DoConstructTightNesting(
return 0;
}
innerBlock = &std::get<parser::Block>(doConstruct->t);
if (doConstruct->IsDoConcurrent()) {
const auto &loopControl = doConstruct->GetLoopControl();
if (loopControl) {
if (const auto *concurrentControl{
std::get_if<parser::LoopControl::Concurrent>(&loopControl->u)}) {
const auto &concurrentHeader =
std::get<Fortran::parser::ConcurrentHeader>(concurrentControl->t);
const auto &controls =
std::get<std::list<Fortran::parser::ConcurrentControl>>(
concurrentHeader.t);
return controls.size();
}
}
return 0;
}
if (innerBlock->size() == 1) {
if (const auto *execConstruct{
std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) {
Expand Down Expand Up @@ -598,9 +613,14 @@ void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {
std::get<std::optional<parser::DoConstruct>>(x.t))};
const parser::Block *innerBlock{nullptr};
if (DoConstructTightNesting(doConstruct, innerBlock) < depth) {
context_.Say(source,
"!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,
std::intmax_t{depth});
if (doConstruct && doConstruct->IsDoConcurrent())
context_.Say(source,
"!$CUF KERNEL DO (%jd) must be followed by a DO CONCURRENT construct with at least %jd indices"_err_en_US,
std::intmax_t{depth}, std::intmax_t{depth});
else
context_.Say(source,
"!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,
std::intmax_t{depth});
}
if (innerBlock) {
DeviceContextChecker<true>{context_}.Check(*innerBlock);
Expand Down
39 changes: 39 additions & 0 deletions flang/test/Lower/CUDA/cuda-doconc.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s

! Check if do concurrent works inside cuf kernel directive

subroutine doconc1
integer :: i, n
integer, managed :: a(3)
a(:) = -1
n = 3
n = n - 1
!$cuf kernel do
do concurrent(i=1:n)
a(i) = 1
end do
end

! CHECK: func.func @_QPdoconc1() {
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: cuf.kernel<<<*, *>>>
! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<i32>

subroutine doconc2
integer :: i, j, m, n
integer, managed :: a(2, 4)
m = 2
n = 4
a(:,:) = -1
!$cuf kernel do
do concurrent(i=1:m,j=1:n)
a(i,j) = i+j
end do
end

! CHECK: func.func @_QPdoconc2() {
! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: cuf.kernel<<<*, *>>> (%arg0 : i32, %arg1 : i32) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really the semantic of !$cuf kernel do? For me it means !$cuf kernel do(1) so only the first range is part of the cuf kernel operation and the rest should be nested inside. Let me know if the semantic is different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really the semantic of !$cuf kernel do? For me it means !$cuf kernel do(1) so only the first range is part of the cuf kernel operation and the rest should be nested inside. Let me know if the semantic is different.

Yes, it's the same semantic with do and do(1), we have this multi range with cuf.kernel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then only one range should be on the cuf.kernel. The rest of the ranges should be nested inside the op.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then only one range should be on the cuf.kernel. The rest of the ranges should be nested inside the op.

After discussion, we decide to keep all the range info in cuf.kernel. Scheduling of loops will be handled later based on loop number in cuf.kernel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not setting the n attribute if it is one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not setting the n attribute if it is one?

Correct. If it's omitted, it's one by default.

! CHECK: %{{.*}} = fir.load %[[DECLI]]#0 : !fir.ref<i32>
! CHECK: %{{.*}} = fir.load %[[DECLJ]]#0 : !fir.ref<i32>
4 changes: 4 additions & 0 deletions flang/test/Semantics/cuf09.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ program main
!$cuf kernel do <<< 1, 2 >>>
do concurrent (j=1:10)
end do
!ERROR: !$CUF KERNEL DO (2) must be followed by a DO CONCURRENT construct with at least 2 indices
!$cuf kernel do(2) <<< 1, 2 >>>
do concurrent (j=1:10)
end do
!$cuf kernel do <<< 1, 2 >>>
do 1 j=1,10
1 continue ! ok
Expand Down