diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 4789493399ec2..d43adaa4131ce 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -3952,3 +3952,18 @@ def HLSLNumThreads: InheritableAttr { let LangOpts = [HLSL]; let Documentation = [NumThreadsDocs]; } + +def HLSLShader : InheritableAttr { + let Spellings = [Microsoft<"shader">]; + let Subjects = SubjectList<[HLSLEntry]>; + let LangOpts = [HLSL]; + let Args = [EnumArgument<"Stage", "EnvironmentType", + ["pixel", "vertex", "geometry", "hull", "domain", "compute", + "library", "raygeneration", "intersection", "anyhit", + "closestHit", "miss", "callable", "mesh", "amplification"], + ["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute", + "Library", "RayGeneration", "Intersection", "AnyHit", + "ClosestHit", "Miss", "Callable", "Mesh", "Amplification"] + >]; + let Documentation = [Undocumented]; +} diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index c5171359e7e4c..5e31a027415dd 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11575,6 +11575,8 @@ def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in % def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">; def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">; def err_hlsl_attribute_param_mismatch : Error<"%0 attribute parameters do not match the previous declaration">; +def err_hlsl_invalid_attribute_argument : Error< + "%0 attribute argument not supported: %1">; } // end of sema component. diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 5ec03391e287b..5c8d703d9a03a 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -3474,6 +3474,8 @@ class Sema final { HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D, const AttributeCommonInfo &AL, int X, int Y, int Z); + HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL, + HLSLShaderAttr::EnvironmentType Stage); void mergeDeclAttributes(NamedDecl *New, Decl *Old, AvailabilityMergeKind AMK = AMK_Redeclaration); diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index a2d3722f2efb8..38adf38e8af5b 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -2789,6 +2789,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D, else if (const auto *NT = dyn_cast(Attr)) NewAttr = S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ()); + else if (const auto *SA = dyn_cast(Attr)) + NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getStage()); else if (Attr->shouldInheritEvenIfAlreadyPresent() || !DeclHasAttr(D, Attr)) NewAttr = cast(Attr->clone(S.Context)); diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 4b5201db7517c..b5a86086bb976 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -6910,6 +6910,44 @@ HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D, return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z); } +static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) { + if (AL.getNumArgs() != 1) { + S.Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1; + return; + } + + StringRef Str; + SourceLocation ArgLoc; + if (!S.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc)) + return; + + HLSLShaderAttr::EnvironmentType Stage; + if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, Stage)) { + S.Diag(AL.getLoc(), diag::err_hlsl_invalid_attribute_argument) + << AL << "'" + std::string(Str) + "'"; + return; + } + + // TODO: check function match the shader stage. + + HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, Stage); + if (NewAttr) + D->addAttr(NewAttr); +} + +HLSLShaderAttr* Sema::mergeHLSLShaderAttr(Decl* D, + const AttributeCommonInfo& AL, + HLSLShaderAttr::EnvironmentType Stage) { + if (HLSLShaderAttr *NT = D->getAttr()) { + if (NT->getStage() != Stage) { + Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; + Diag(AL.getLoc(), diag::note_conflicting_attribute); + } + return nullptr; + } + return HLSLShaderAttr::Create(Context, Stage, AL); +} + static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) { if (!S.LangOpts.CPlusPlus) { S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang) @@ -8776,6 +8814,9 @@ static void ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, case ParsedAttr::AT_HLSLNumThreads: handleHLSLNumThreadsAttr(S, D, AL); break; + case ParsedAttr::AT_HLSLShader: + handleHLSLShaderAttr(S, D, AL); + break; case ParsedAttr::AT_AbiTag: handleAbiTagAttr(S, D, AL); diff --git a/clang/test/SemaHLSL/shader_attr.hlsl b/clang/test/SemaHLSL/shader_attr.hlsl new file mode 100644 index 0000000000000..66574f4ec5c50 --- /dev/null +++ b/clang/test/SemaHLSL/shader_attr.hlsl @@ -0,0 +1,76 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -ast-dump -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -ast-dump -o - %s -DFAIL -verify + + +#ifdef FAIL + +// expected-warning@+1 {{'shader' attribute only applies to global functions}} +[shader("compute")] +struct Fido { + // expected-warning@+1 {{'shader' attribute only applies to global functions}} + [shader("pixel")] + void wag() {} + + // expected-warning@+1 {{'shader' attribute only applies to global functions}} + [shader("vertex")] + static void oops() {} +}; + +// expected-warning@+1 {{'shader' attribute only applies to global functions}} + [shader("vertex")] +static void oops() {} + +namespace spec { +// expected-warning@+1 {{'shader' attribute only applies to global functions}} + [shader("vertex")] +static void oops() {} +} + +// expected-error@+1 {{'shader' attribute parameters do not match the previous declaration}} +[shader("compute")] +// expected-note@+1 {{conflicting attribute is here}} +[shader("vertex")] +int doubledUp() { + return 1; +} + +// expected-note@+1 {{conflicting attribute is here}} +[shader("vertex")] +int forwardDecl(); + +// expected-error@+1 {{'shader' attribute parameters do not match the previous declaration}} +[shader("compute")] +int forwardDecl() { + return 1; +} + + +// expected-error@+1 {{'shader' attribute takes one argument}} +[shader()] +// expected-error@+1 {{'shader' attribute takes one argument}} +[shader(1,2)] +// expected-error@+1 {{'shader' attribute requires a string}} +[shader(1)] +// expected-error@+1 {{'shader' attribute argument not supported: 'cs'}} +[shader("cs")] + +#endif // END of FAIL + +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +[shader("compute")] +int entry() { + return 1; +} + +// Because these two attributes match, they should both appear in the AST +[shader("compute")] +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +int secondFn(); + +[shader("compute")] +// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} Compute +int secondFn() { + return 1; +} + +