Skip to content

Commit 68c5111

Browse files
committed
clang format
Signed-off-by: gejin <[email protected]>
1 parent 1d53420 commit 68c5111

File tree

1 file changed

+90
-59
lines changed

1 file changed

+90
-59
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 90 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#include <array>
3232
#include <functional>
3333
#include <initializer_list>
34-
// #include <iostream>
3534

3635
using namespace clang;
3736
using namespace std::placeholders;
@@ -106,8 +105,6 @@ class Util {
106105
/// \param Name the function name to be checked against.
107106
static bool isSyclFunction(const FunctionDecl *FD, StringRef Name);
108107

109-
// static void printSyclFunction(const FunctionDecl *FD);
110-
111108
/// Checks whether given clang type is a full specialization of the SYCL
112109
/// specialization constant class.
113110
static bool isSyclSpecConstantType(const QualType &Ty);
@@ -368,7 +365,8 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
368365
Callee = Callee->getCanonicalDecl();
369366
assert(Callee && "Device function canonical decl must be available");
370367

371-
bool RecursionAllowed = TraverseEsimdKernel && Callee->hasAttr<NoInlineAttr>();
368+
bool RecursionAllowed =
369+
TraverseEsimdKernel && Callee->hasAttr<NoInlineAttr>();
372370
// Remember that all SYCL kernel functions have deferred
373371
// instantiation as template functions. It means that
374372
// all functions used by kernel have already been parsed and have
@@ -469,60 +467,83 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
469467

470468
// The call graph for this translation unit.
471469
CallGraph SYCLCG;
472-
// The set of functions called by a kernel function.
473-
// llvm::SmallPtrSet<FunctionDecl *, 10> KernelSet;
474-
llvm::DenseMap<FunctionDecl *, llvm::SmallPtrSet<FunctionDecl *, 10>> SYCLKernelInvokeMap;
475-
// The set of recursive functions identified while building the
476-
// kernel set, this is used for error diagnostics.
470+
// Record the mapping between each SYCL kernel or SYCL_EXTERNAL function and
471+
// functions called by it.
472+
llvm::DenseMap<FunctionDecl *, llvm::SmallPtrSet<FunctionDecl *, 10>>
473+
SYCLKernelInvokeMap;
474+
// The set of recursive functions identified while going through the call
475+
// graph for each SYCL kernel or SYCL_EXTERNAL function, this is used for
476+
// error diagnostics.
477477
llvm::SmallPtrSet<FunctionDecl *, 10> RecursiveSet;
478478

479-
void WalkSYCLKernelCGNode(FunctionDecl *KernelNode, llvm::SmallPtrSet<FunctionDecl *, 10> &VisitedSet) {
480-
CallGraphNode *CGN = SYCLCG.getNode(KernelNode);
481-
if (!CGN) return;
482-
for (const CallGraphNode *CI : *CGN) {
483-
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
479+
// Traverse over call graph to find all functions directly or indirectly
480+
// called by SYCLFunc, all functions called are recorded in VisitedSet.
481+
// Each time a function is visited, check if it is a "cyclic" function
482+
// which means the function will directly or indrectly call itself.
483+
// All "cyclic" functions are recorded in RecursiveSet for later recursion
484+
// error diagnostics.
485+
void
486+
WalkSYCLFunctionCG(FunctionDecl *SYCLFunc,
487+
llvm::SmallPtrSet<FunctionDecl *, 10> &VisitedSet,
488+
llvm::SmallPtrSet<FunctionDecl *, 10> CyclicCheckSet) {
489+
CallGraphNode *SYCLFuncCGN = SYCLCG.getNode(SYCLFunc);
490+
if (!SYCLFuncCGN)
491+
return;
492+
for (const CallGraphNode *CGN : *SYCLFuncCGN) {
493+
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CGN->getDecl())) {
484494
Callee = Callee->getCanonicalDecl();
485495
if (VisitedSet.count(Callee) == 0) {
486-
// Util::printSyclFunction(Callee);
487496
VisitedSet.insert(Callee);
488-
if (IsCylicSYCLFunction(Callee)) {
489-
RecursiveSet.insert(Callee);
497+
if (CyclicCheckSet.count(Callee) == 0) {
498+
if (IsCyclicSYCLFunction(Callee)) {
499+
RecursiveSet.insert(Callee);
500+
}
501+
CyclicCheckSet.insert(Callee);
490502
}
491-
WalkSYCLKernelCGNode(Callee, VisitedSet);
503+
WalkSYCLFunctionCG(Callee, VisitedSet, CyclicCheckSet);
492504
}
493505
}
494506
}
495507
}
496508

497-
bool IsCylicSYCLFunction(FunctionDecl *SYCLFunc) {
498-
CallGraphNode *CGN = SYCLCG.getNode(SYCLFunc);
499-
if (!CGN) return false;
509+
// Traverses over call graph to find whether a SYCL function will directly
510+
// or indirectly call itself.
511+
bool IsCyclicSYCLFunction(FunctionDecl *SYCLFunc) {
512+
CallGraphNode *SYCLFuncCGN = SYCLCG.getNode(SYCLFunc);
513+
if (!SYCLFuncCGN)
514+
return false;
500515
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet;
501-
for (const CallGraphNode *CI : *CGN) {
502-
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
516+
for (const CallGraphNode *CGN : *SYCLFuncCGN) {
517+
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CGN->getDecl())) {
503518
Callee = Callee->getCanonicalDecl();
504-
if (SYCLFunctionVisit(Callee, SYCLFunc, VisitedSet)) return true;
519+
if (IsSYCLFunctionInvoke(Callee, SYCLFunc, VisitedSet))
520+
return true;
505521
}
506522
}
507523
return false;
508524
}
509525

510-
bool SYCLFunctionVisit(FunctionDecl *FDSrc, FunctionDecl *FDDes, llvm::SmallPtrSet<FunctionDecl *, 10> &VisitedSet) {
511-
if (FDSrc == FDDes) return true;
512-
CallGraphNode *CGN = SYCLCG.getNode(FDSrc);
513-
if (!CGN) return false;
514-
for (const CallGraphNode *CI : *CGN) {
515-
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
526+
// Traverse over call graph to find whether FDSrc wil directly or indrectly
527+
// call FDDes.
528+
bool IsSYCLFunctionInvoke(FunctionDecl *FDSrc, FunctionDecl *FDDes,
529+
llvm::SmallPtrSet<FunctionDecl *, 10> &VisitedSet) {
530+
if (FDSrc == FDDes)
531+
return true;
532+
CallGraphNode *SYCLFuncCGN = SYCLCG.getNode(FDSrc);
533+
if (!SYCLFuncCGN)
534+
return false;
535+
for (const CallGraphNode *CGN : *SYCLFuncCGN) {
536+
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CGN->getDecl())) {
516537
Callee = Callee->getCanonicalDecl();
517538
if (VisitedSet.count(Callee) == 0) {
518539
VisitedSet.insert(Callee);
519-
if (SYCLFunctionVisit(Callee, FDDes, VisitedSet)) return true;
540+
if (IsSYCLFunctionInvoke(Callee, FDDes, VisitedSet))
541+
return true;
520542
}
521543
}
522544
}
523545
return false;
524546
}
525-
526547
void SetTraverseEsimdKernel(bool TEsimdK) { TraverseEsimdKernel = TEsimdK; }
527548
bool IsTraverseEsimdKernel() const { return TraverseEsimdKernel; }
528549
// Traverses over CallGraph to collect list of attributes applied to
@@ -3301,15 +3322,15 @@ void Sema::MarkDevice(void) {
33013322
}
33023323
}
33033324

3325+
llvm::SmallPtrSet<FunctionDecl *, 10> SYCLFunctionCyclicCheckSet;
33043326
for (Decl *D : syclDeviceDecls()) {
33053327
if (auto SYCLKernel = dyn_cast<FunctionDecl>(D)) {
3306-
// Util::printSyclFunction(SYCLKernel);
33073328
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet;
3308-
// Marker.CollectKernelSet(SYCLKernel, SYCLKernel, VisitedSet);
33093329
llvm::SmallPtrSet<FunctionDecl *, 10> SYCLKernelInvokeSet;
3310-
llvm::SmallPtrSet<FunctionDecl *, 10> RecurSet;
3311-
Marker.WalkSYCLKernelCGNode(SYCLKernel, SYCLKernelInvokeSet);
3312-
Marker.SYCLKernelInvokeMap.insert(std::make_pair(SYCLKernel, SYCLKernelInvokeSet));
3330+
Marker.WalkSYCLFunctionCG(SYCLKernel, SYCLKernelInvokeSet,
3331+
SYCLFunctionCyclicCheckSet);
3332+
Marker.SYCLKernelInvokeMap.insert(
3333+
std::make_pair(SYCLKernel, SYCLKernelInvokeSet));
33133334

33143335
// Let's propagate attributes from device functions to a SYCL kernels
33153336
llvm::SmallVector<Attr *, 4> Attrs;
@@ -3426,19 +3447,44 @@ void Sema::MarkDevice(void) {
34263447
}
34273448
}
34283449

3429-
llvm::SmallPtrSet<FunctionDecl *, 10> NoEsimdModeVisited;
3430-
llvm::SmallPtrSet<FunctionDecl *, 10> EsimdModeVisited;
3450+
// Previously, we traverse all SYCL functions including kernel functions
3451+
// and functions called by kernel to do some sema check following same rules.
3452+
// But this mechanism can't meet our requirements now as different rules may
3453+
// be required for different type of SYCL kernel. For example, recursion is
3454+
// not allowed for normal SYCL kernel but recursive function with "noinline"
3455+
// attribute is permitted in SYCL ESIMD kernel.
3456+
// Now, we traverse each SYCL kernel together with all functions called by it
3457+
// directly or indirectly. Before the sema check, we will check kernel type
3458+
// to decide whether different rules are used. Currently, only normal SYCL
3459+
// kernel and SYCL ESIMD kernel are supported here. A internal flag in Marker
3460+
// is used to indicate whether current check is on SYCL ESIMD kernel and all
3461+
// functions it call directly or indirectly.
3462+
// The traverse will introduce unnecessary duplicate checks as a function can
3463+
// be called in different SYCL kernels. For example, if "kernel1" and
3464+
// "kernel2" both call function "foo" and these 2 kernels are same type, we
3465+
// firstly check "kernel1" with "foo" and when we check "kernel2", "foo" will
3466+
// be checked again but it is unnecessary as the rules are same with previous
3467+
// check. In order to avoid unnecessary checks, we use 2 set to record all
3468+
// checked functions in different mode, only the one not in set will be
3469+
// checked.
3470+
llvm::SmallPtrSet<FunctionDecl *, 10> NonESIMDModeCheckSet;
3471+
llvm::SmallPtrSet<FunctionDecl *, 10> ESIMDModeCheckSet;
34313472
for (Decl *D : syclDeviceDecls()) {
34323473
if (auto SYCLKernelDec = dyn_cast<FunctionDecl>(D)) {
34333474
if (FunctionDecl *SYCLKernelDef = SYCLKernelDec->getDefinition()) {
34343475
Marker.SetTraverseEsimdKernel(SYCLKernelDef->hasAttr<SYCLSimdAttr>());
3435-
auto &VisitedSYCLFunctionSet = (SYCLKernelDef->hasAttr<SYCLSimdAttr>() ? EsimdModeVisited : NoEsimdModeVisited);
3476+
auto &SYCLFunctionCheckSet =
3477+
(SYCLKernelDef->hasAttr<SYCLSimdAttr>() ? ESIMDModeCheckSet
3478+
: NonESIMDModeCheckSet);
34363479
Marker.TraverseStmt(SYCLKernelDef->getBody());
3437-
for (FunctionDecl * SYCLFunctionDec : Marker.SYCLKernelInvokeMap[SYCLKernelDec]) {
3438-
if (VisitedSYCLFunctionSet.count(SYCLFunctionDec) != 0) continue;
3439-
if (FunctionDecl *SYCLFunctionDef = SYCLFunctionDec->getDefinition()) {
3480+
for (FunctionDecl *SYCLFunctionDec :
3481+
Marker.SYCLKernelInvokeMap[SYCLKernelDec]) {
3482+
if (SYCLFunctionCheckSet.count(SYCLFunctionDec) != 0)
3483+
continue;
3484+
if (FunctionDecl *SYCLFunctionDef =
3485+
SYCLFunctionDec->getDefinition()) {
34403486
Marker.TraverseStmt(SYCLFunctionDef->getBody());
3441-
VisitedSYCLFunctionSet.insert(SYCLFunctionDec);
3487+
SYCLFunctionCheckSet.insert(SYCLFunctionDec);
34423488
}
34433489
}
34443490
}
@@ -4174,21 +4220,6 @@ bool Util::isSyclFunction(const FunctionDecl *FD, StringRef Name) {
41744220
return matchContext(DC, Scopes);
41754221
}
41764222

4177-
/* void Util::printSyclFunction(const FunctionDecl *FD) {
4178-
if (FD != nullptr && FD->isFunctionOrMethod() && FD->getIdentifier() &&
4179-
!FD->getName().empty()) {
4180-
std::cout << "sycl function: " << FD->getName().str() << std::endl;
4181-
return;
4182-
}
4183-
4184-
if (FD == nullptr) std::cout << "null FD" << std::endl;
4185-
else if (!FD->isFunctionOrMethod()) std::cout << "FD is not function or method" << std::endl;
4186-
else if (!FD->getIdentifier()) std::cout << "FD has no identifier" << std::endl;
4187-
else if (FD->getName().empty()) std::cout << "FD empty name" << std::endl;
4188-
else;
4189-
return;
4190-
}*/
4191-
41924223
bool Util::isAccessorPropertyListType(const QualType &Ty) {
41934224
const StringRef &Name = "accessor_property_list";
41944225
std::array<DeclContextDesc, 4> Scopes = {

0 commit comments

Comments
 (0)