diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index cad1b634f8924..814fe4cccede6 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -1286,16 +1286,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &, const std::optional &, const std::string &procName, const std::string &argName); -inline bool CanCUDASymbolHaveSaveAttr(const Symbol &sym) { - if (const auto *details = - sym.GetUltimate().detailsIf()) { - if (details->cudaDataAttr() && - *details->cudaDataAttr() != common::CUDADataAttr::Unified) { - return false; - } - } - return true; -} +bool CanCUDASymbolHaveSaveAttr(const Symbol &sym); inline bool IsCUDADeviceSymbol(const Symbol &sym) { if (const auto *details = diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h index f3cfa9b99fb4d..ca58e1065d3e5 100644 --- a/flang/include/flang/Semantics/tools.h +++ b/flang/include/flang/Semantics/tools.h @@ -654,6 +654,8 @@ DirectComponentIterator::const_iterator FindAllocatableOrPointerDirectComponent( const DerivedTypeSpec &); PotentialComponentIterator::const_iterator FindPolymorphicAllocatablePotentialComponent(const DerivedTypeSpec &); +UltimateComponentIterator::const_iterator +FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &); // The LabelEnforce class (given a set of labels) provides an error message if // there is a branch to a label which is not in the given set. diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index fcacdb93d662b..d912fc167a62e 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -2173,6 +2173,25 @@ bool IsAutomatic(const Symbol &original) { return false; } +bool CanCUDASymbolHaveSaveAttr(const Symbol &sym) { + if (const auto *details{ + sym.GetUltimate().detailsIf()}) { + const Fortran::semantics::DeclTypeSpec *type{details->type()}; + const Fortran::semantics::DerivedTypeSpec *derived{ + type ? type->AsDerived() : nullptr}; + if (derived) { + if (FindCUDADeviceAllocatableUltimateComponent(*derived)) { + return false; + } + } + if (details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Unified) { + return false; + } + } + return true; +} + bool IsSaved(const Symbol &original) { const Symbol &symbol{GetAssociationRoot(original)}; const Scope &scope{symbol.owner()}; @@ -2195,7 +2214,7 @@ bool IsSaved(const Symbol &original) { } else if (scopeKind == Scope::Kind::Module || (scopeKind == Scope::Kind::MainProgram && (symbol.attrs().test(Attr::TARGET) || evaluate::IsCoarray(symbol)) && - Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol))) { + CanCUDASymbolHaveSaveAttr(symbol))) { // 8.5.16p4 // In main programs, implied SAVE matters only for pointer // initialization targets and coarrays. @@ -2205,7 +2224,7 @@ bool IsSaved(const Symbol &original) { (features.IsEnabled( common::LanguageFeature::SaveBigMainProgramVariables) && symbol.size() > 32)) && - Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol)) { + CanCUDASymbolHaveSaveAttr(symbol)) { // With SaveBigMainProgramVariables, keeping all unsaved main program // variables of 32 bytes or less on the stack allows keeping numerical and // logical scalars, small scalar characters or derived, small arrays, and @@ -2223,15 +2242,15 @@ bool IsSaved(const Symbol &original) { } else if (symbol.test(Symbol::Flag::InDataStmt)) { return true; } else if (const auto *object{symbol.detailsIf()}; - object && object->init()) { + object && object->init()) { return true; } else if (IsProcedurePointer(symbol) && symbol.has() && symbol.get().init()) { return true; } else if (scope.hasSAVE()) { return true; // bare SAVE statement - } else if (const Symbol * block{FindCommonBlockContaining(symbol)}; - block && block->attrs().test(Attr::SAVE)) { + } else if (const Symbol *block{FindCommonBlockContaining(symbol)}; + block && block->attrs().test(Attr::SAVE)) { return true; // in COMMON with SAVE } else { return false; diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp index d053179448c00..498bbc18709ab 100644 --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -1081,6 +1081,19 @@ const Scope *FindCUDADeviceContext(const Scope *scope) { }); } +bool IsDeviceAllocatable(const Symbol &symbol) { + if (IsAllocatable(symbol)) { + if (const auto *details{ + symbol.GetUltimate().detailsIf()}) { + if (details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Pinned) { + return true; + } + } + } + return false; +} + std::optional GetCUDADataAttr(const Symbol *symbol) { const auto *object{ symbol ? symbol->detailsIf() : nullptr}; @@ -1426,6 +1439,12 @@ FindPolymorphicAllocatablePotentialComponent(const DerivedTypeSpec &derived) { potentials.begin(), potentials.end(), IsPolymorphicAllocatable); } +UltimateComponentIterator::const_iterator +FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &derived) { + UltimateComponentIterator ultimates{derived}; + return std::find_if(ultimates.begin(), ultimates.end(), IsDeviceAllocatable); +} + const Symbol *FindUltimateComponent(const DerivedTypeSpec &derived, const std::function &predicate) { UltimateComponentIterator ultimates{derived}; @@ -1788,4 +1807,4 @@ bool HadUseError( } } -} // namespace Fortran::semantics \ No newline at end of file +} // namespace Fortran::semantics diff --git a/flang/test/Lower/CUDA/cuda-derived.cuf b/flang/test/Lower/CUDA/cuda-derived.cuf new file mode 100644 index 0000000000000..d280ac722d08f --- /dev/null +++ b/flang/test/Lower/CUDA/cuda-derived.cuf @@ -0,0 +1,20 @@ +! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s + +module m1 + type ty_device + integer, device, allocatable, dimension(:) :: x + end type + + type t1; real, device, allocatable :: a(:); end type + type t2; type(t1) :: b; end type +end module + +program main + use m1 + type(ty_device) :: a + type(t2) :: b +end + +! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"} +! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box>>}> {bindc_name = "a", uniq_name = "_QFEa"} +! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box>>}>}> {bindc_name = "b", uniq_name = "_QFEb"}