From e5c75de8577d943eab6069f52c9b785192201a12 Mon Sep 17 00:00:00 2001 From: Srividya Sundaram Date: Mon, 5 Oct 2020 17:29:24 -0700 Subject: [PATCH 1/7] [SYCL] Add KernelNameTypeVisitor validation check --- clang/lib/Sema/SemaSYCL.cpp | 65 +++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 6e25ae7a6a975..5497ac26ddac5 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -2829,24 +2829,27 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { } // namespace class SYCLKernelNameTypeVisitor - : public TypeVisitor, - public ConstTemplateArgumentVisitor { + : public TypeVisitor, + public ConstTemplateArgumentVisitor { Sema &S; SourceLocation KernelInvocationFuncLoc; - using InnerTypeVisitor = TypeVisitor; + using InnerTypeVisitor = TypeVisitor; using InnerTAVisitor = - ConstTemplateArgumentVisitor; + ConstTemplateArgumentVisitor; + bool IsInvalid = false; public: SYCLKernelNameTypeVisitor(Sema &S, SourceLocation KernelInvocationFuncLoc) : S(S), KernelInvocationFuncLoc(KernelInvocationFuncLoc) {} - void Visit(QualType T) { + bool isValid() { return !IsInvalid; } + + bool Visit(QualType T) { if (T.isNull()) - return; + return false; const CXXRecordDecl *RD = T->getAsCXXRecordDecl(); if (!RD) - return; + return false; // If KernelNameType has template args visit each template arg via // ConstTemplateArgumentVisitor if (const auto *TSD = dyn_cast(RD)) { @@ -2857,29 +2860,33 @@ class SYCLKernelNameTypeVisitor } else { InnerTypeVisitor::Visit(T.getTypePtr()); } + return true; } - void Visit(const TemplateArgument &TA) { + bool Visit(const TemplateArgument &TA) { if (TA.isNull()) - return; - InnerTAVisitor::Visit(TA); + return false; + return InnerTAVisitor::Visit(TA); } - void VisitEnumType(const EnumType *T) { + bool VisitEnumType(const EnumType *T) { const EnumDecl *ED = T->getDecl(); if (!ED->isScoped() && !ED->isFixed()) { S.Diag(KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named) << /* Unscoped enum requires fixed underlying type */ 2; S.Diag(ED->getSourceRange().getBegin(), diag::note_entity_declared_at) << ED; + IsInvalid = true; + return isValid(); } + return true; } - void VisitRecordType(const RecordType *T) { + bool VisitRecordType(const RecordType *T) { return VisitTagDecl(T->getDecl()); } - void VisitTagDecl(const TagDecl *Tag) { + bool VisitTagDecl(const TagDecl *Tag) { bool UnnamedLambdaEnabled = S.getASTContext().getLangOpts().SYCLUnnamedLambda; if (!Tag->getDeclContext()->isTranslationUnit() && @@ -2888,43 +2895,49 @@ class SYCLKernelNameTypeVisitor if (KernelNameIsMissing) { S.Diag(KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named) << /* kernel name is missing */ 0; + IsInvalid = true; } else { - if (Tag->isCompleteDefinition()) + if (Tag->isCompleteDefinition()) { S.Diag(KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named) << /* kernel name is not globally-visible */ 1; - else + IsInvalid = true; + } else S.Diag(KernelInvocationFuncLoc, diag::warn_sycl_implicit_decl); S.Diag(Tag->getSourceRange().getBegin(), diag::note_previous_decl) << Tag->getName(); } + return isValid(); } + return true; } - void VisitTypeTemplateArgument(const TemplateArgument &TA) { + bool VisitTypeTemplateArgument(const TemplateArgument &TA) { QualType T = TA.getAsType(); if (const auto *ET = T->getAs()) - VisitEnumType(ET); + return VisitEnumType(ET); else - Visit(T); + return Visit(T); } - void VisitIntegralTemplateArgument(const TemplateArgument &TA) { + bool VisitIntegralTemplateArgument(const TemplateArgument &TA) { QualType T = TA.getIntegralType(); if (const EnumType *ET = T->getAs()) - VisitEnumType(ET); + return VisitEnumType(ET); + return true; } - void VisitTemplateTemplateArgument(const TemplateArgument &TA) { + bool VisitTemplateTemplateArgument(const TemplateArgument &TA) { TemplateDecl *TD = TA.getAsTemplate().getAsTemplateDecl(); TemplateParameterList *TemplateParams = TD->getTemplateParameters(); for (NamedDecl *P : *TemplateParams) { if (NonTypeTemplateParmDecl *TemplateParam = dyn_cast(P)) if (const EnumType *ET = TemplateParam->getType()->getAs()) - VisitEnumType(ET); + return VisitEnumType(ET); } + return true; } }; @@ -2970,10 +2983,10 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc, SyclKernelArgsSizeChecker ArgsSizeChecker(*this, Args[0]->getExprLoc()); KernelObjVisitor Visitor{*this}; - SYCLKernelNameTypeVisitor KernelTypeVisitor(*this, Args[0]->getExprLoc()); + SYCLKernelNameTypeVisitor KernelNameTypeVisitor(*this, Args[0]->getExprLoc()); // Emit diagnostics for SYCL device kernels only if (LangOpts.SYCLIsDevice) - KernelTypeVisitor.Visit(KernelNameType); + KernelNameTypeVisitor.Visit(KernelNameType); DiagnosingSYCLKernel = true; Visitor.VisitRecordBases(KernelObj, FieldChecker, UnionChecker, DecompMarker); Visitor.VisitRecordFields(KernelObj, FieldChecker, UnionChecker, @@ -2987,7 +3000,9 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc, Visitor.VisitRecordFields(KernelObj, ArgsSizeChecker); } DiagnosingSYCLKernel = false; - if (!FieldChecker.isValid() || !UnionChecker.isValid()) + // Set the kernel function as invalid, if any of the checkers fail validation. + if (!FieldChecker.isValid() || !UnionChecker.isValid() || + !KernelNameTypeVisitor.isValid()) KernelFunc->setInvalidDecl(); } From 52bd9a3e26a23a91ccd27a7bcfae7fbccc735791 Mon Sep 17 00:00:00 2001 From: Srividya Sundaram Date: Mon, 5 Oct 2020 18:39:20 -0700 Subject: [PATCH 2/7] Move 'DiagnosingSYCLKernel' before invoking visitors --- clang/lib/Sema/SemaSYCL.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 5497ac26ddac5..9d4e7c5615bbd 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -2984,10 +2984,12 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc, KernelObjVisitor Visitor{*this}; SYCLKernelNameTypeVisitor KernelNameTypeVisitor(*this, Args[0]->getExprLoc()); + + DiagnosingSYCLKernel = true; + // Emit diagnostics for SYCL device kernels only if (LangOpts.SYCLIsDevice) KernelNameTypeVisitor.Visit(KernelNameType); - DiagnosingSYCLKernel = true; Visitor.VisitRecordBases(KernelObj, FieldChecker, UnionChecker, DecompMarker); Visitor.VisitRecordFields(KernelObj, FieldChecker, UnionChecker, DecompMarker); From 49c8ba526f3094039e9503876ee12919ca9eb742 Mon Sep 17 00:00:00 2001 From: Srividya Sundaram Date: Tue, 6 Oct 2020 08:15:59 -0700 Subject: [PATCH 3/7] Fix code review comments --- clang/lib/Sema/SemaSYCL.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 9d4e7c5615bbd..2f671eff4067b 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -2845,11 +2845,9 @@ class SYCLKernelNameTypeVisitor bool isValid() { return !IsInvalid; } bool Visit(QualType T) { - if (T.isNull()) - return false; + assert(!T.isNull() && "KernelNameType cannot be null"); const CXXRecordDecl *RD = T->getAsCXXRecordDecl(); - if (!RD) - return false; + assert(RD && "KernelNameType is not a record"); // If KernelNameType has template args visit each template arg via // ConstTemplateArgumentVisitor if (const auto *TSD = dyn_cast(RD)) { @@ -2864,8 +2862,7 @@ class SYCLKernelNameTypeVisitor } bool Visit(const TemplateArgument &TA) { - if (TA.isNull()) - return false; + assert(!TA.isNull() && "TemplateArgument cannot be null"); return InnerTAVisitor::Visit(TA); } @@ -2917,8 +2914,7 @@ class SYCLKernelNameTypeVisitor QualType T = TA.getAsType(); if (const auto *ET = T->getAs()) return VisitEnumType(ET); - else - return Visit(T); + return Visit(T); } bool VisitIntegralTemplateArgument(const TemplateArgument &TA) { From b953d80fa5a0d04d46deebc8b22279a9a88c591b Mon Sep 17 00:00:00 2001 From: Srividya Sundaram Date: Tue, 6 Oct 2020 09:06:23 -0700 Subject: [PATCH 4/7] Undo assert --- clang/lib/Sema/SemaSYCL.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 2f671eff4067b..88e23d3b4304c 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -2845,9 +2845,11 @@ class SYCLKernelNameTypeVisitor bool isValid() { return !IsInvalid; } bool Visit(QualType T) { - assert(!T.isNull() && "KernelNameType cannot be null"); + if (T.isNull()) + return false; const CXXRecordDecl *RD = T->getAsCXXRecordDecl(); - assert(RD && "KernelNameType is not a record"); + if (!RD) + return false; // If KernelNameType has template args visit each template arg via // ConstTemplateArgumentVisitor if (const auto *TSD = dyn_cast(RD)) { @@ -2862,7 +2864,8 @@ class SYCLKernelNameTypeVisitor } bool Visit(const TemplateArgument &TA) { - assert(!TA.isNull() && "TemplateArgument cannot be null"); + if (TA.isNull()) + return false; return InnerTAVisitor::Visit(TA); } From 7da6c64d77af420bddf0e9d2e441d83dc21f540e Mon Sep 17 00:00:00 2001 From: Srividya Sundaram Date: Wed, 7 Oct 2020 10:30:16 -0700 Subject: [PATCH 5/7] Reset return type of base visitor classes to void --- clang/lib/Sema/SemaSYCL.cpp | 47 ++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 88e23d3b4304c..8c0dcf920ccaf 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -2829,13 +2829,13 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { } // namespace class SYCLKernelNameTypeVisitor - : public TypeVisitor, - public ConstTemplateArgumentVisitor { + : public TypeVisitor, + public ConstTemplateArgumentVisitor { Sema &S; SourceLocation KernelInvocationFuncLoc; - using InnerTypeVisitor = TypeVisitor; + using InnerTypeVisitor = TypeVisitor; using InnerTAVisitor = - ConstTemplateArgumentVisitor; + ConstTemplateArgumentVisitor; bool IsInvalid = false; public: @@ -2844,12 +2844,12 @@ class SYCLKernelNameTypeVisitor bool isValid() { return !IsInvalid; } - bool Visit(QualType T) { + void Visit(QualType T) { if (T.isNull()) - return false; + return; const CXXRecordDecl *RD = T->getAsCXXRecordDecl(); if (!RD) - return false; + return; // If KernelNameType has template args visit each template arg via // ConstTemplateArgumentVisitor if (const auto *TSD = dyn_cast(RD)) { @@ -2860,16 +2860,15 @@ class SYCLKernelNameTypeVisitor } else { InnerTypeVisitor::Visit(T.getTypePtr()); } - return true; } - bool Visit(const TemplateArgument &TA) { + void Visit(const TemplateArgument &TA) { if (TA.isNull()) - return false; - return InnerTAVisitor::Visit(TA); + return; + InnerTAVisitor::Visit(TA); } - bool VisitEnumType(const EnumType *T) { + void VisitEnumType(const EnumType *T) { const EnumDecl *ED = T->getDecl(); if (!ED->isScoped() && !ED->isFixed()) { S.Diag(KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named) @@ -2877,16 +2876,14 @@ class SYCLKernelNameTypeVisitor S.Diag(ED->getSourceRange().getBegin(), diag::note_entity_declared_at) << ED; IsInvalid = true; - return isValid(); } - return true; } - bool VisitRecordType(const RecordType *T) { + void VisitRecordType(const RecordType *T) { return VisitTagDecl(T->getDecl()); } - bool VisitTagDecl(const TagDecl *Tag) { + void VisitTagDecl(const TagDecl *Tag) { bool UnnamedLambdaEnabled = S.getASTContext().getLangOpts().SYCLUnnamedLambda; if (!Tag->getDeclContext()->isTranslationUnit() && @@ -2908,35 +2905,31 @@ class SYCLKernelNameTypeVisitor S.Diag(Tag->getSourceRange().getBegin(), diag::note_previous_decl) << Tag->getName(); } - return isValid(); } - return true; } - bool VisitTypeTemplateArgument(const TemplateArgument &TA) { + void VisitTypeTemplateArgument(const TemplateArgument &TA) { QualType T = TA.getAsType(); if (const auto *ET = T->getAs()) - return VisitEnumType(ET); - return Visit(T); + VisitEnumType(ET); + Visit(T); } - bool VisitIntegralTemplateArgument(const TemplateArgument &TA) { + void VisitIntegralTemplateArgument(const TemplateArgument &TA) { QualType T = TA.getIntegralType(); if (const EnumType *ET = T->getAs()) - return VisitEnumType(ET); - return true; + VisitEnumType(ET); } - bool VisitTemplateTemplateArgument(const TemplateArgument &TA) { + void VisitTemplateTemplateArgument(const TemplateArgument &TA) { TemplateDecl *TD = TA.getAsTemplate().getAsTemplateDecl(); TemplateParameterList *TemplateParams = TD->getTemplateParameters(); for (NamedDecl *P : *TemplateParams) { if (NonTypeTemplateParmDecl *TemplateParam = dyn_cast(P)) if (const EnumType *ET = TemplateParam->getType()->getAs()) - return VisitEnumType(ET); + VisitEnumType(ET); } - return true; } }; From 19aacc7b0e88c5f0f5befea8879740c275bf0e8f Mon Sep 17 00:00:00 2001 From: Srividya Sundaram Date: Thu, 8 Oct 2020 07:57:04 -0700 Subject: [PATCH 6/7] Add the "else" branch back Co-authored-by: Mariya Podchishchaeva --- clang/lib/Sema/SemaSYCL.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 8c0dcf920ccaf..f3838a20f6baa 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -2912,6 +2912,7 @@ class SYCLKernelNameTypeVisitor QualType T = TA.getAsType(); if (const auto *ET = T->getAs()) VisitEnumType(ET); + else Visit(T); } From 8ca6ac726ae4125dab71696ba65e797867bbad35 Mon Sep 17 00:00:00 2001 From: Srividya Sundaram Date: Thu, 8 Oct 2020 08:04:58 -0700 Subject: [PATCH 7/7] Fix clang-format error introduced by github ui --- clang/lib/Sema/SemaSYCL.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index f3838a20f6baa..6e8fb9fd139e7 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -2912,8 +2912,8 @@ class SYCLKernelNameTypeVisitor QualType T = TA.getAsType(); if (const auto *ET = T->getAs()) VisitEnumType(ET); - else - Visit(T); + else + Visit(T); } void VisitIntegralTemplateArgument(const TemplateArgument &TA) {