Skip to content

Commit eb5907d

Browse files
authored
[flang][cuda] Avoid issuing data transfers in device context (#90247)
Data transfers should not be issued in device functions.
1 parent 9ee8e38 commit eb5907d

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3806,16 +3806,34 @@ class FirConverter : public Fortran::lower::AbstractConverter {
38063806
return temps;
38073807
}
38083808

3809+
// Check if the insertion point is currently in a device context. HostDevice
3810+
// subprograms are not considered a fully device context, so this returns false
3811+
// for them.
3812+
static bool isDeviceContext(fir::FirOpBuilder &builder) {
3813+
if (builder.getRegion().getParentOfType<fir::CUDAKernelOp>())
3814+
return true;
3815+
if (auto funcOp =
3816+
builder.getRegion().getParentOfType<mlir::func::FuncOp>()) {
3817+
if (auto cudaProcAttr =
3818+
funcOp.getOperation()->getAttrOfType<fir::CUDAProcAttributeAttr>(
3819+
fir::getCUDAAttrName())) {
3820+
return cudaProcAttr.getValue() != fir::CUDAProcAttribute::Host &&
3821+
cudaProcAttr.getValue() != fir::CUDAProcAttribute::HostDevice;
3822+
}
3823+
}
3824+
return false;
3825+
}
3826+
38093827
void genDataAssignment(
38103828
const Fortran::evaluate::Assignment &assign,
38113829
const Fortran::evaluate::ProcedureRef *userDefinedAssignment) {
38123830
mlir::Location loc = getCurrentLocation();
38133831
fir::FirOpBuilder &builder = getFirOpBuilder();
38143832

3815-
bool isInDeviceContext =
3816-
builder.getRegion().getParentOfType<fir::CUDAKernelOp>();
3817-
bool isCUDATransfer = Fortran::evaluate::HasCUDAAttrs(assign.lhs) ||
3818-
Fortran::evaluate::HasCUDAAttrs(assign.rhs);
3833+
bool isInDeviceContext = isDeviceContext(builder);
3834+
bool isCUDATransfer = (Fortran::evaluate::HasCUDAAttrs(assign.lhs) ||
3835+
Fortran::evaluate::HasCUDAAttrs(assign.rhs)) &&
3836+
!isInDeviceContext;
38193837
bool hasCUDAImplicitTransfer =
38203838
Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
38213839
llvm::SmallVector<mlir::Value> implicitTemps;
@@ -3878,7 +3896,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
38783896
Fortran::lower::StatementContext localStmtCtx;
38793897
hlfir::Entity rhs = evaluateRhs(localStmtCtx);
38803898
hlfir::Entity lhs = evaluateLhs(localStmtCtx);
3881-
if (isCUDATransfer && !hasCUDAImplicitTransfer && !isInDeviceContext)
3899+
if (isCUDATransfer && !hasCUDAImplicitTransfer)
38823900
genCUDADataTransfer(builder, loc, assign, lhs, rhs);
38833901
else
38843902
builder.create<hlfir::AssignOp>(loc, rhs, lhs,

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,21 @@ end subroutine
141141
! CHECK: fir.cuda_kernel<<<*, *>>>
142142
! CHECK-NOT: fir.cuda_data_transfer
143143
! CHECK: hlfir.assign
144+
145+
attributes(global) subroutine sub5(a)
146+
integer, device :: a
147+
integer :: i
148+
a = i
149+
end subroutine
150+
151+
! CHECK-LABEL: func.func @_QPsub5
152+
! CHECK-NOT: fir.cuda_data_transfer
153+
154+
attributes(host,device) subroutine sub6(a)
155+
integer, device :: a
156+
integer :: i
157+
a = i
158+
end subroutine
159+
160+
! CHECK-LABEL: func.func @_QPsub6
161+
! CHECK: fir.cuda_data_transfer

0 commit comments

Comments
 (0)