Skip to content

Commit c5509fe

Browse files
authored
[HLSL] Support packoffset attribute in AST (#89836)
Add HLSLPackOffsetAttr to save packoffset in AST. Since we have to parse the attribute manually in ParseHLSLAnnotations, we could create the ParsedAttribute with a integer offset parameter instead of string. This approach avoids parsing the string if the offset is saved as a string in HLSLPackOffsetAttr. For #57914.
1 parent 943617d commit c5509fe

File tree

10 files changed

+484
-0
lines changed

10 files changed

+484
-0
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4415,6 +4415,18 @@ def HLSLResourceBinding: InheritableAttr {
44154415
let Documentation = [HLSLResourceBindingDocs];
44164416
}
44174417

4418+
def HLSLPackOffset: HLSLAnnotationAttr {
4419+
let Spellings = [HLSLAnnotation<"packoffset">];
4420+
let LangOpts = [HLSL];
4421+
let Args = [IntArgument<"Subcomponent">, IntArgument<"Component">];
4422+
let Documentation = [HLSLPackOffsetDocs];
4423+
let AdditionalMembers = [{
4424+
unsigned getOffset() {
4425+
return subcomponent * 4 + component;
4426+
}
4427+
}];
4428+
}
4429+
44184430
def HLSLSV_DispatchThreadID: HLSLAnnotationAttr {
44194431
let Spellings = [HLSLAnnotation<"SV_DispatchThreadID">];
44204432
let Subjects = SubjectList<[ParmVar, Field]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7408,6 +7408,26 @@ The full documentation is available here: https://docs.microsoft.com/en-us/windo
74087408
}];
74097409
}
74107410

7411+
def HLSLPackOffsetDocs : Documentation {
7412+
let Category = DocCatFunction;
7413+
let Content = [{
7414+
The packoffset attribute is used to change the layout of a cbuffer.
7415+
Attribute spelling in HLSL is: ``packoffset( c[Subcomponent][.component] )``.
7416+
A subcomponent is a register number, which is an integer. A component is in the form of [.xyzw].
7417+
7418+
Examples:
7419+
7420+
.. code-block:: c++
7421+
7422+
cbuffer A {
7423+
float3 a : packoffset(c0.y);
7424+
float4 b : packoffset(c4);
7425+
}
7426+
7427+
The full documentation is available here: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-packoffset
7428+
}];
7429+
}
7430+
74117431
def HLSLSV_DispatchThreadIDDocs : Documentation {
74127432
let Category = DocCatFunction;
74137433
let Content = [{

clang/include/clang/Basic/DiagnosticGroups.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,6 +1507,9 @@ def BranchProtection : DiagGroup<"branch-protection">;
15071507
// Warnings for HLSL Clang extensions
15081508
def HLSLExtension : DiagGroup<"hlsl-extensions">;
15091509

1510+
// Warning for mix packoffset and non-packoffset.
1511+
def HLSLMixPackOffset : DiagGroup<"mix-packoffset">;
1512+
15101513
// Warnings for DXIL validation
15111514
def DXILValidation : DiagGroup<"dxil-validation">;
15121515

clang/include/clang/Basic/DiagnosticParseKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,5 +1754,7 @@ def err_hlsl_separate_attr_arg_and_number : Error<"wrong argument format for hls
17541754
def ext_hlsl_access_specifiers : ExtWarn<
17551755
"access specifiers are a clang HLSL extension">,
17561756
InGroup<HLSLExtension>;
1757+
def err_hlsl_unsupported_component : Error<"invalid component '%0' used; expected 'x', 'y', 'z', or 'w'">;
1758+
def err_hlsl_packoffset_invalid_reg : Error<"invalid resource class specifier '%0' for packoffset, expected 'c'">;
17571759

17581760
} // end of Parser diagnostics

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12184,6 +12184,11 @@ def err_hlsl_init_priority_unsupported : Error<
1218412184
def err_hlsl_unsupported_register_type : Error<"invalid resource class specifier '%0' used; expected 'b', 's', 't', or 'u'">;
1218512185
def err_hlsl_unsupported_register_number : Error<"register number should be an integer">;
1218612186
def err_hlsl_expected_space : Error<"invalid space specifier '%0' used; expected 'space' followed by an integer, like space1">;
12187+
def warn_hlsl_packoffset_mix : Warning<"cannot mix packoffset elements with nonpackoffset elements in a cbuffer">,
12188+
InGroup<HLSLMixPackOffset>;
12189+
def err_hlsl_packoffset_overlap : Error<"packoffset overlap between %0, %1">;
12190+
def err_hlsl_packoffset_cross_reg_boundary : Error<"packoffset cannot cross register boundary">;
12191+
def err_hlsl_packoffset_alignment_mismatch : Error<"packoffset at 'y' not match alignment %0 required by %1">;
1218712192
def err_hlsl_pointers_unsupported : Error<
1218812193
"%select{pointers|references}0 are unsupported in HLSL">;
1218912194

clang/lib/Parse/ParseHLSL.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,94 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
183183
return;
184184
}
185185
} break;
186+
case ParsedAttr::AT_HLSLPackOffset: {
187+
// Parse 'packoffset( c[Subcomponent][.component] )'.
188+
// Check '('.
189+
if (ExpectAndConsume(tok::l_paren, diag::err_expected_lparen_after)) {
190+
SkipUntil(tok::r_paren, StopAtSemi); // skip through )
191+
return;
192+
}
193+
// Check c[Subcomponent] as an identifier.
194+
if (!Tok.is(tok::identifier)) {
195+
Diag(Tok.getLocation(), diag::err_expected) << tok::identifier;
196+
SkipUntil(tok::r_paren, StopAtSemi); // skip through )
197+
return;
198+
}
199+
StringRef OffsetStr = Tok.getIdentifierInfo()->getName();
200+
SourceLocation SubComponentLoc = Tok.getLocation();
201+
if (OffsetStr[0] != 'c') {
202+
Diag(Tok.getLocation(), diag::err_hlsl_packoffset_invalid_reg)
203+
<< OffsetStr;
204+
SkipUntil(tok::r_paren, StopAtSemi); // skip through )
205+
return;
206+
}
207+
OffsetStr = OffsetStr.substr(1);
208+
unsigned SubComponent = 0;
209+
if (!OffsetStr.empty()) {
210+
// Make sure SubComponent is a number.
211+
if (OffsetStr.getAsInteger(10, SubComponent)) {
212+
Diag(SubComponentLoc.getLocWithOffset(1),
213+
diag::err_hlsl_unsupported_register_number);
214+
SkipUntil(tok::r_paren, StopAtSemi); // skip through )
215+
return;
216+
}
217+
}
218+
unsigned Component = 0;
219+
ConsumeToken(); // consume identifier.
220+
SourceLocation ComponentLoc;
221+
if (Tok.is(tok::period)) {
222+
ConsumeToken(); // consume period.
223+
if (!Tok.is(tok::identifier)) {
224+
Diag(Tok.getLocation(), diag::err_expected) << tok::identifier;
225+
SkipUntil(tok::r_paren, StopAtSemi); // skip through )
226+
return;
227+
}
228+
StringRef ComponentStr = Tok.getIdentifierInfo()->getName();
229+
ComponentLoc = Tok.getLocation();
230+
ConsumeToken(); // consume identifier.
231+
// Make sure Component is a single character.
232+
if (ComponentStr.size() != 1) {
233+
Diag(ComponentLoc, diag::err_hlsl_unsupported_component)
234+
<< ComponentStr;
235+
SkipUntil(tok::r_paren, StopAtSemi); // skip through )
236+
return;
237+
}
238+
switch (ComponentStr[0]) {
239+
case 'x':
240+
case 'r':
241+
Component = 0;
242+
break;
243+
case 'y':
244+
case 'g':
245+
Component = 1;
246+
break;
247+
case 'z':
248+
case 'b':
249+
Component = 2;
250+
break;
251+
case 'w':
252+
case 'a':
253+
Component = 3;
254+
break;
255+
default:
256+
Diag(ComponentLoc, diag::err_hlsl_unsupported_component)
257+
<< ComponentStr;
258+
SkipUntil(tok::r_paren, StopAtSemi); // skip through )
259+
return;
260+
}
261+
}
262+
ASTContext &Ctx = Actions.getASTContext();
263+
QualType SizeTy = Ctx.getSizeType();
264+
uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
265+
ArgExprs.push_back(IntegerLiteral::Create(
266+
Ctx, llvm::APInt(SizeTySize, SubComponent), SizeTy, SubComponentLoc));
267+
ArgExprs.push_back(IntegerLiteral::Create(
268+
Ctx, llvm::APInt(SizeTySize, Component), SizeTy, ComponentLoc));
269+
if (ExpectAndConsume(tok::r_paren, diag::err_expected)) {
270+
SkipUntil(tok::r_paren, StopAtSemi); // skip through )
271+
return;
272+
}
273+
} break;
186274
case ParsedAttr::UnknownAttribute:
187275
Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
188276
return;

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7309,6 +7309,55 @@ static void handleHLSLSV_DispatchThreadIDAttr(Sema &S, Decl *D,
73097309
D->addAttr(::new (S.Context) HLSLSV_DispatchThreadIDAttr(S.Context, AL));
73107310
}
73117311

7312+
static void handleHLSLPackOffsetAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
7313+
if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
7314+
S.Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
7315+
<< AL << "shader constant in a constant buffer";
7316+
return;
7317+
}
7318+
7319+
uint32_t SubComponent;
7320+
if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(0), SubComponent))
7321+
return;
7322+
uint32_t Component;
7323+
if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(1), Component))
7324+
return;
7325+
7326+
QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
7327+
// Check if T is an array or struct type.
7328+
// TODO: mark matrix type as aggregate type.
7329+
bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
7330+
7331+
// Check Component is valid for T.
7332+
if (Component) {
7333+
unsigned Size = S.getASTContext().getTypeSize(T);
7334+
if (IsAggregateTy || Size > 128) {
7335+
S.Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
7336+
return;
7337+
} else {
7338+
// Make sure Component + sizeof(T) <= 4.
7339+
if ((Component * 32 + Size) > 128) {
7340+
S.Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
7341+
return;
7342+
}
7343+
QualType EltTy = T;
7344+
if (const auto *VT = T->getAs<VectorType>())
7345+
EltTy = VT->getElementType();
7346+
unsigned Align = S.getASTContext().getTypeAlign(EltTy);
7347+
if (Align > 32 && Component == 1) {
7348+
// NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
7349+
// So we only need to check Component 1 here.
7350+
S.Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
7351+
<< Align << EltTy;
7352+
return;
7353+
}
7354+
}
7355+
}
7356+
7357+
D->addAttr(::new (S.Context)
7358+
HLSLPackOffsetAttr(S.Context, AL, SubComponent, Component));
7359+
}
7360+
73127361
static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
73137362
StringRef Str;
73147363
SourceLocation ArgLoc;
@@ -9730,6 +9779,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
97309779
case ParsedAttr::AT_HLSLSV_DispatchThreadID:
97319780
handleHLSLSV_DispatchThreadIDAttr(S, D, AL);
97329781
break;
9782+
case ParsedAttr::AT_HLSLPackOffset:
9783+
handleHLSLPackOffsetAttr(S, D, AL);
9784+
break;
97339785
case ParsedAttr::AT_HLSLShader:
97349786
handleHLSLShaderAttr(S, D, AL);
97359787
break;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,89 @@ Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
3939
return Result;
4040
}
4141

42+
// Calculate the size of a legacy cbuffer type based on
43+
// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
44+
static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
45+
QualType T) {
46+
unsigned Size = 0;
47+
constexpr unsigned CBufferAlign = 128;
48+
if (const RecordType *RT = T->getAs<RecordType>()) {
49+
const RecordDecl *RD = RT->getDecl();
50+
for (const FieldDecl *Field : RD->fields()) {
51+
QualType Ty = Field->getType();
52+
unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
53+
unsigned FieldAlign = 32;
54+
if (Ty->isAggregateType())
55+
FieldAlign = CBufferAlign;
56+
Size = llvm::alignTo(Size, FieldAlign);
57+
Size += FieldSize;
58+
}
59+
} else if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
60+
if (unsigned ElementCount = AT->getSize().getZExtValue()) {
61+
unsigned ElementSize =
62+
calculateLegacyCbufferSize(Context, AT->getElementType());
63+
unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
64+
Size = AlignedElementSize * (ElementCount - 1) + ElementSize;
65+
}
66+
} else if (const VectorType *VT = T->getAs<VectorType>()) {
67+
unsigned ElementCount = VT->getNumElements();
68+
unsigned ElementSize =
69+
calculateLegacyCbufferSize(Context, VT->getElementType());
70+
Size = ElementSize * ElementCount;
71+
} else {
72+
Size = Context.getTypeSize(T);
73+
}
74+
return Size;
75+
}
76+
4277
void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
4378
auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
4479
BufDecl->setRBraceLoc(RBrace);
80+
81+
// Validate packoffset.
82+
llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
83+
bool HasPackOffset = false;
84+
bool HasNonPackOffset = false;
85+
for (auto *Field : BufDecl->decls()) {
86+
VarDecl *Var = dyn_cast<VarDecl>(Field);
87+
if (!Var)
88+
continue;
89+
if (Field->hasAttr<HLSLPackOffsetAttr>()) {
90+
PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
91+
HasPackOffset = true;
92+
} else {
93+
HasNonPackOffset = true;
94+
}
95+
}
96+
97+
if (HasPackOffset && HasNonPackOffset)
98+
Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
99+
100+
if (HasPackOffset) {
101+
ASTContext &Context = getASTContext();
102+
// Make sure no overlap in packoffset.
103+
// Sort PackOffsetVec by offset.
104+
std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
105+
[](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
106+
const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
107+
return LHS.second->getOffset() < RHS.second->getOffset();
108+
});
109+
110+
for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
111+
VarDecl *Var = PackOffsetVec[i].first;
112+
HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
113+
unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
114+
unsigned Begin = Attr->getOffset() * 32;
115+
unsigned End = Begin + Size;
116+
unsigned NextBegin = PackOffsetVec[i + 1].second->getOffset() * 32;
117+
if (End > NextBegin) {
118+
VarDecl *NextVar = PackOffsetVec[i + 1].first;
119+
Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
120+
<< NextVar << Var;
121+
}
122+
}
123+
}
124+
45125
SemaRef.PopDeclContext();
46126
}
47127

0 commit comments

Comments
 (0)