@@ -3806,16 +3806,34 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3806
3806
return temps;
3807
3807
}
3808
3808
3809
+ // Returns true when the builder's insertion point is currently in a device context: either nested inside a fir::CUDAKernelOp, or inside a function whose CUDA procedure attribute is present and is neither Host nor HostDevice.
3810
+ // HostDevice subprograms are not considered a full device context, so the function returns false for them.
3811
+ // It also returns false when there is no enclosing function or the enclosing function carries no CUDA procedure attribute.
3812
+ static bool isDeviceContext (fir::FirOpBuilder &builder) {
3813
+ if (builder.getRegion ().getParentOfType <fir::CUDAKernelOp>()) // enclosed by a fir::CUDAKernelOp
3814
+ return true ;
3815
+ if (auto funcOp = // otherwise inspect the enclosing function, if any
3816
+ builder.getRegion ().getParentOfType <mlir::func::FuncOp>()) {
3817
+ if (auto cudaProcAttr = // attribute stored under the name given by fir::getCUDAAttrName()
3818
+ funcOp.getOperation ()->getAttrOfType <fir::CUDAProcAttributeAttr>(
3819
+ fir::getCUDAAttrName ())) {
3820
+ return cudaProcAttr.getValue () != fir::CUDAProcAttribute::Host && // any value other than Host/HostDevice counts as device
3821
+ cudaProcAttr.getValue () != fir::CUDAProcAttribute::HostDevice;
3822
+ }
3823
+ }
3824
+ return false ; // neither an enclosing kernel op nor a CUDA procedure attribute was found
3825
+ }
3826
+
3809
3827
void genDataAssignment (
3810
3828
const Fortran::evaluate::Assignment &assign,
3811
3829
const Fortran::evaluate::ProcedureRef *userDefinedAssignment) {
3812
3830
mlir::Location loc = getCurrentLocation ();
3813
3831
fir::FirOpBuilder &builder = getFirOpBuilder ();
3814
3832
3815
- bool isInDeviceContext =
3816
- builder. getRegion (). getParentOfType <fir::CUDAKernelOp>();
3817
- bool isCUDATransfer = Fortran::evaluate::HasCUDAAttrs (assign.lhs ) ||
3818
- Fortran::evaluate::HasCUDAAttrs (assign. rhs ) ;
3833
+ bool isInDeviceContext = isDeviceContext (builder);
3834
+ bool isCUDATransfer = ( Fortran::evaluate::HasCUDAAttrs (assign. lhs ) ||
3835
+ Fortran::evaluate::HasCUDAAttrs (assign.rhs )) &&
3836
+ !isInDeviceContext ;
3819
3837
bool hasCUDAImplicitTransfer =
3820
3838
Fortran::evaluate::HasCUDAImplicitTransfer (assign.rhs );
3821
3839
llvm::SmallVector<mlir::Value> implicitTemps;
@@ -3878,7 +3896,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3878
3896
Fortran::lower::StatementContext localStmtCtx;
3879
3897
hlfir::Entity rhs = evaluateRhs (localStmtCtx);
3880
3898
hlfir::Entity lhs = evaluateLhs (localStmtCtx);
3881
- if (isCUDATransfer && !hasCUDAImplicitTransfer && !isInDeviceContext )
3899
+ if (isCUDATransfer && !hasCUDAImplicitTransfer)
3882
3900
genCUDADataTransfer (builder, loc, assign, lhs, rhs);
3883
3901
else
3884
3902
builder.create <hlfir::AssignOp>(loc, rhs, lhs,
0 commit comments