@@ -99,10 +99,23 @@ class Util {
99
99
/// \param Tmpl whether the class is a template instantiation or a simple record
100
100
static bool isSyclType (const QualType &Ty, StringRef Name, bool Tmpl = false );
101
101
102
+ /// Checks whether the given function is a standard SYCL API function with the
103
+ /// given name.
104
+ /// \param FD the function being checked.
105
+ /// \param Name the function name to be checked against.
106
+ static bool isSyclFunction (const FunctionDecl *FD, StringRef Name);
107
+
102
108
/// Checks whether the given clang type is a full specialization of the SYCL
103
109
/// specialization constant class.
104
110
static bool isSyclSpecConstantType (const QualType &Ty);
105
111
112
+ /// Checks the declaration context hierarchy.
113
+ /// \param DC the context of the item to be checked.
114
+ /// \param Scopes the declaration scopes leading from the item context to the
115
+ /// translation unit (excluding the latter).
116
+ static bool matchContext (const DeclContext *DC,
117
+ ArrayRef<Util::DeclContextDesc> Scopes);
118
+
106
119
/// Checks whether the given clang type is declared in the given hierarchy of
107
120
/// declaration contexts.
108
121
/// \param Ty the clang type being checked
@@ -487,6 +500,21 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
487
500
FunctionDecl *FD = WorkList.back ().first ;
488
501
FunctionDecl *ParentFD = WorkList.back ().second ;
489
502
503
+ // To implement rounding-up of a parallel-for range the
504
+ // SYCL header implementation modifies the kernel call like this:
505
+ // auto Wrapper = [=](TransformedArgType Arg) {
506
+ // if (Arg[0] >= NumWorkItems[0])
507
+ // return;
508
+ // Arg.set_allowed_range(NumWorkItems);
509
+ // KernelFunc(Arg);
510
+ // };
511
+ //
512
+ // This transformation leads to a condition where a kernel body
513
+ // function becomes callable from a new kernel body function.
514
+ // Hence this test.
515
+ if ((ParentFD == KernelBody) && isSYCLKernelBodyFunction (FD))
516
+ KernelBody = FD;
517
+
490
518
if ((ParentFD == SYCLKernel) && isSYCLKernelBodyFunction (FD)) {
491
519
assert (!KernelBody && " inconsistent call graph - only one kernel body "
492
520
" function can be called" );
@@ -2667,15 +2695,63 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
2667
2695
return !SemaRef.getASTContext ().hasSameType (FD->getType (), Ty);
2668
2696
}
2669
2697
2698
+ // Sets a flag if the kernel is a parallel_for that calls the
2699
+ // free function API "this_item".
2700
+ void setThisItemIsCalled (const CXXRecordDecl *KernelObj,
2701
+ FunctionDecl *KernelFunc) {
2702
+ if (getKernelInvocationKind (KernelFunc) != InvokeParallelFor)
2703
+ return ;
2704
+
2705
+ const CXXMethodDecl *WGLambdaFn = getOperatorParens (KernelObj);
2706
+ if (!WGLambdaFn)
2707
+ return ;
2708
+
2709
+ // The call graph for this translation unit.
2710
+ CallGraph SYCLCG;
2711
+ SYCLCG.addToCallGraph (SemaRef.getASTContext ().getTranslationUnitDecl ());
2712
+ using ChildParentPair =
2713
+ std::pair<const FunctionDecl *, const FunctionDecl *>;
2714
+ llvm::SmallPtrSet<const FunctionDecl *, 16 > Visited;
2715
+ llvm::SmallVector<ChildParentPair, 16 > WorkList;
2716
+ WorkList.push_back ({WGLambdaFn, nullptr });
2717
+
2718
+ while (!WorkList.empty ()) {
2719
+ const FunctionDecl *FD = WorkList.back ().first ;
2720
+ WorkList.pop_back ();
2721
+ if (!Visited.insert (FD).second )
2722
+ continue ; // We've already seen this Decl
2723
+
2724
+ // Check whether this call is to sycl::this_item().
2725
+ if (Util::isSyclFunction (FD, " this_item" )) {
2726
+ Header.setCallsThisItem (true );
2727
+ return ;
2728
+ }
2729
+
2730
+ CallGraphNode *N = SYCLCG.getNode (FD);
2731
+ if (!N)
2732
+ continue ;
2733
+
2734
+ for (const CallGraphNode *CI : *N) {
2735
+ if (auto *Callee = dyn_cast<FunctionDecl>(CI->getDecl ())) {
2736
+ Callee = Callee->getMostRecentDecl ();
2737
+ if (!Visited.count (Callee))
2738
+ WorkList.push_back ({Callee, FD});
2739
+ }
2740
+ }
2741
+ }
2742
+ }
2743
+
2670
2744
public:
2671
2745
static constexpr const bool VisitInsideSimpleContainers = false ;
2672
2746
// Starts the integration-header entry for the kernel (name, name type, stable
// name, location and ESIMD flag) and then records whether the kernel calls
// this_item. The `-`/`+` lines below show the signature change that adds the
// KernelFunc parameter.
SyclKernelIntHeaderCreator (Sema &S, SYCLIntegrationHeader &H,
2673
2747
const CXXRecordDecl *KernelObj, QualType NameType,
2674
- StringRef Name, StringRef StableName)
2748
+ StringRef Name, StringRef StableName,
2749
+ FunctionDecl *KernelFunc)
2675
2750
: SyclKernelFieldHandler(S), Header(H) {
2676
2751
bool IsSIMDKernel = isESIMDKernelType (KernelObj);
2677
2752
Header.startKernel (Name, NameType, StableName, KernelObj->getLocation (),
2678
2753
IsSIMDKernel);
2754
// Must run after startKernel: setCallsThisItem writes to the current kernel
// descriptor and asserts that one exists.
+ setThisItemIsCalled (KernelObj, KernelFunc);
2679
2755
}
2680
2756
2681
2757
bool handleSyclAccessorType (const CXXRecordDecl *RD,
@@ -3123,7 +3199,7 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
3123
3199
SyclKernelIntHeaderCreator int_header (
3124
3200
*this , getSyclIntegrationHeader (), KernelObj,
3125
3201
calculateKernelNameType (Context, KernelCallerFunc), KernelName,
3126
- StableName);
3202
+ StableName, KernelCallerFunc );
3127
3203
3128
3204
KernelObjVisitor Visitor{*this };
3129
3205
Visitor.VisitRecordBases (KernelObj, kernel_decl, kernel_body, int_header);
@@ -3842,6 +3918,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
3842
3918
O << " __SYCL_DLL_LOCAL\n " ;
3843
3919
O << " static constexpr bool isESIMD() { return " << K.IsESIMDKernel
3844
3920
<< " ; }\n " ;
3921
+ O << " __SYCL_DLL_LOCAL\n " ;
3922
+ O << " static constexpr bool callsThisItem() { return " ;
3923
+ O << K.CallsThisItem << " ; }\n " ;
3845
3924
O << " };\n " ;
3846
3925
CurStart += N;
3847
3926
}
@@ -3900,6 +3979,12 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) {
3900
3979
SpecConsts.emplace_back (std::make_pair (IDType, IDName.str ()));
3901
3980
}
3902
3981
3982
+ void SYCLIntegrationHeader::setCallsThisItem (bool B) {
3983
+ KernelDesc *K = getCurKernelDesc ();
3984
+ assert (K && " no kernels" );
3985
+ K->CallsThisItem = B;
3986
+ }
3987
+
3903
3988
SYCLIntegrationHeader::SYCLIntegrationHeader (DiagnosticsEngine &_Diag,
3904
3989
bool _UnnamedLambdaSupport,
3905
3990
Sema &_S)
@@ -3967,6 +4052,21 @@ bool Util::isSyclType(const QualType &Ty, StringRef Name, bool Tmpl) {
3967
4052
return matchQualifiedTypeName (Ty, Scopes);
3968
4053
}
3969
4054
4055
+ bool Util::isSyclFunction (const FunctionDecl *FD, StringRef Name) {
4056
+ if (!FD->isFunctionOrMethod () || !FD->getIdentifier () ||
4057
+ FD->getName ().empty () || Name != FD->getName ())
4058
+ return false ;
4059
+
4060
+ const DeclContext *DC = FD->getDeclContext ();
4061
+ if (DC->isTranslationUnit ())
4062
+ return false ;
4063
+
4064
+ std::array<DeclContextDesc, 2 > Scopes = {
4065
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " cl" },
4066
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " sycl" }};
4067
+ return matchContext (DC, Scopes);
4068
+ }
4069
+
3970
4070
bool Util::isAccessorPropertyListType (const QualType &Ty) {
3971
4071
const StringRef &Name = " accessor_property_list" ;
3972
4072
std::array<DeclContextDesc, 4 > Scopes = {
@@ -3977,21 +4077,15 @@ bool Util::isAccessorPropertyListType(const QualType &Ty) {
3977
4077
return matchQualifiedTypeName (Ty, Scopes);
3978
4078
}
3979
4079
3980
- bool Util::matchQualifiedTypeName (const QualType &Ty ,
3981
- ArrayRef<Util::DeclContextDesc> Scopes) {
3982
- // The idea: check the declaration context chain starting from the type
4080
+ bool Util::matchContext (const DeclContext *Ctx ,
4081
+ ArrayRef<Util::DeclContextDesc> Scopes) {
4082
+ // The idea: check the declaration context chain starting from the item
3983
4083
// itself. At each step check the context is of expected kind
3984
4084
// (namespace) and name.
3985
- const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
3986
-
3987
- if (!RecTy)
3988
- return false ; // only classes/structs supported
3989
- const auto *Ctx = cast<DeclContext>(RecTy);
3990
4085
StringRef Name = " " ;
3991
4086
3992
4087
for (const auto &Scope : llvm::reverse (Scopes)) {
3993
4088
clang::Decl::Kind DK = Ctx->getDeclKind ();
3994
-
3995
4089
if (DK != Scope.first )
3996
4090
return false ;
3997
4091
@@ -4005,11 +4099,21 @@ bool Util::matchQualifiedTypeName(const QualType &Ty,
4005
4099
Name = cast<NamespaceDecl>(Ctx)->getName ();
4006
4100
break ;
4007
4101
default :
4008
- llvm_unreachable (" matchQualifiedTypeName : decl kind not supported" );
4102
+ llvm_unreachable (" matchContext : decl kind not supported" );
4009
4103
}
4010
4104
if (Name != Scope.second )
4011
4105
return false ;
4012
4106
Ctx = Ctx->getParent ();
4013
4107
}
4014
4108
return Ctx->isTranslationUnit ();
4015
4109
}
4110
+
4111
+ bool Util::matchQualifiedTypeName (const QualType &Ty,
4112
+ ArrayRef<Util::DeclContextDesc> Scopes) {
4113
+ const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
4114
+
4115
+ if (!RecTy)
4116
+ return false ; // only classes/structs supported
4117
+ const auto *Ctx = cast<DeclContext>(RecTy);
4118
+ return Util::matchContext (Ctx, Scopes);
4119
+ }
0 commit comments