Skip to content

Commit 6b8d1c8

Browse files
committed
[Clang] Amend SME attributes with support for ZT0.
This patch builds on top of #76971 and implements support for: * __arm_new("zt0") * __arm_in("zt0") * __arm_out("zt0") * __arm_inout("zt0") * __arm_preserves("zt0") I'll rebase this patch after I land #76971, as this is currently a stacked pull-request on top of #76971. The patch is ready for review but won't be able to land until LLVM implements support for handling ZT0 state.
1 parent fcc66ce commit 6b8d1c8

File tree

13 files changed

+185
-4
lines changed

13 files changed

+185
-4
lines changed

clang/include/clang/AST/Type.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4037,9 +4037,12 @@ class FunctionType : public Type {
40374037
// Describes the value of the state using ArmStateValue.
40384038
SME_ZAShift = 2,
40394039
SME_ZAMask = 0b111 << SME_ZAShift,
4040+
SME_ZT0Shift = 5,
4041+
SME_ZT0Mask = 0b111 << SME_ZT0Shift,
40404042

4041-
SME_AttributeMask = 0b111'111 // We only support maximum 6 bits because of
4042-
// the bitmask in FunctionTypeExtraBitfields.
4043+
SME_AttributeMask =
4044+
0b111'111'11 // We can't support more than 8 bits because of
4045+
// the bitmask in FunctionTypeExtraBitfields.
40434046
};
40444047

40454048
enum ArmStateValue : unsigned {
@@ -4054,6 +4057,10 @@ class FunctionType : public Type {
40544057
return (ArmStateValue)((AttrBits & SME_ZAMask) >> SME_ZAShift);
40554058
}
40564059

4060+
static ArmStateValue getArmZT0State(unsigned AttrBits) {
4061+
return (ArmStateValue)((AttrBits & SME_ZT0Mask) >> SME_ZT0Shift);
4062+
}
4063+
40574064
/// A simple holder for various uncommon bits which do not fit in
40584065
/// FunctionTypeBitfields. Aligned to alignof(void *) to maintain the
40594066
/// alignment of subsequent objects in TrailingObjects.
@@ -4065,7 +4072,7 @@ class FunctionType : public Type {
40654072

40664073
/// Any AArch64 SME ACLE type attributes that need to be propagated
40674074
/// on declarations and function pointers.
4068-
unsigned AArch64SMEAttributes : 6;
4075+
unsigned AArch64SMEAttributes : 8;
40694076
FunctionTypeExtraBitfields()
40704077
: NumExceptionType(0), AArch64SMEAttributes(SME_NormalFunction) {}
40714078
};
@@ -4248,7 +4255,7 @@ class FunctionProtoType final
42484255
FunctionType::ExtInfo ExtInfo;
42494256
unsigned Variadic : 1;
42504257
unsigned HasTrailingReturn : 1;
4251-
unsigned AArch64SMEAttributes : 6;
4258+
unsigned AArch64SMEAttributes : 8;
42524259
Qualifiers TypeQuals;
42534260
RefQualifierKind RefQualifier = RQ_None;
42544261
ExceptionSpecInfo ExceptionSpec;

clang/include/clang/Basic/Attr.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2535,6 +2535,9 @@ def ArmNew : InheritableAttr, TargetSpecificAttr<TargetAArch64> {
25352535
bool isNewZA() const {
25362536
return llvm::is_contained(newArgs(), "za");
25372537
}
2538+
bool isNewZT0() const {
2539+
return llvm::is_contained(newArgs(), "zt0");
2540+
}
25382541
}];
25392542
}
25402543

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3692,10 +3692,14 @@ def err_sme_call_in_non_sme_target : Error<
36923692
"call to a streaming function requires 'sme'">;
36933693
def err_sme_za_call_no_za_state : Error<
36943694
"call to a shared ZA function requires the caller to have ZA state">;
3695+
def err_sme_zt0_call_no_zt0_state : Error<
3696+
"call to a shared ZT0 function requires the caller to have ZT0 state">;
36953697
def err_sme_definition_using_sm_in_non_sme_target : Error<
36963698
"function executed in streaming-SVE mode requires 'sme'">;
36973699
def err_sme_definition_using_za_in_non_sme_target : Error<
36983700
"function using ZA state requires 'sme'">;
3701+
def err_sme_definition_using_zt0_in_non_sme2_target : Error<
3702+
"function using ZT0 state requires 'sme2'">;
36993703
def err_conflicting_attributes_arm_state : Error<
37003704
"conflicting attributes for state '%0'">;
37013705
def err_unknown_arm_state : Error<

clang/lib/AST/TypePrinter.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,14 @@ void TypePrinter::printFunctionProtoAfter(const FunctionProtoType *T,
951951
OS << " __arm_out(\"za\")";
952952
if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_InOut)
953953
OS << " __arm_inout(\"za\")";
954+
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Preserves)
955+
OS << " __arm_preserves(\"zt0\")";
956+
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_In)
957+
OS << " __arm_in(\"zt0\")";
958+
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Out)
959+
OS << " __arm_out(\"zt0\")";
960+
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_InOut)
961+
OS << " __arm_inout(\"zt0\")";
954962

955963
printFunctionAfter(Info, OS);
956964

clang/lib/CodeGen/CGCall.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,6 +1782,16 @@ static void AddAttributesFromFunctionProtoType(ASTContext &Ctx,
17821782
FuncAttrs.addAttribute("aarch64_pstate_za_shared");
17831783
FuncAttrs.addAttribute("aarch64_pstate_za_preserved");
17841784
}
1785+
1786+
// ZT0
1787+
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Preserves)
1788+
FuncAttrs.addAttribute("aarch64_zt0_preserved");
1789+
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_In)
1790+
FuncAttrs.addAttribute("aarch64_zt0_in");
1791+
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Out)
1792+
FuncAttrs.addAttribute("aarch64_zt0_out");
1793+
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_InOut)
1794+
FuncAttrs.addAttribute("aarch64_zt0_inout");
17851795
}
17861796

17871797
static void AddAttributesFromAssumes(llvm::AttrBuilder &FuncAttrs,

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2381,6 +2381,8 @@ void CodeGenModule::SetLLVMFunctionAttributesForDefinition(const Decl *D,
23812381
if (auto *Attr = D->getAttr<ArmNewAttr>()) {
23822382
if (Attr->isNewZA())
23832383
B.addAttribute("aarch64_pstate_za_new");
2384+
if (Attr->isNewZT0())
2385+
B.addAttribute("aarch64_zt0_new");
23842386
}
23852387

23862388
// Track whether we need to add the optnone LLVM attribute,

clang/lib/Sema/SemaChecking.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7529,6 +7529,27 @@ void Sema::checkCall(NamedDecl *FDecl, const FunctionProtoType *Proto,
75297529
if (!CallerHasZAState)
75307530
Diag(Loc, diag::err_sme_za_call_no_za_state);
75317531
}
7532+
7533+
// If the callee uses AArch64 SME ZT0 state but the caller doesn't define
7534+
// any, then this is an error.
7535+
FunctionType::ArmStateValue ArmZT0State =
7536+
FunctionType::getArmZT0State(ExtInfo.AArch64SMEAttributes);
7537+
if (ArmZT0State != FunctionType::ARM_None) {
7538+
bool CallerHasZT0State = false;
7539+
if (const auto *CallerFD = dyn_cast<FunctionDecl>(CurContext)) {
7540+
auto *Attr = CallerFD->getAttr<ArmNewAttr>();
7541+
if (Attr && Attr->isNewZT0())
7542+
CallerHasZT0State = true;
7543+
else if (const auto *FPT =
7544+
CallerFD->getType()->getAs<FunctionProtoType>())
7545+
CallerHasZT0State = FunctionType::getArmZT0State(
7546+
FPT->getExtProtoInfo().AArch64SMEAttributes) !=
7547+
FunctionType::ARM_None;
7548+
}
7549+
7550+
if (!CallerHasZT0State)
7551+
Diag(Loc, diag::err_sme_zt0_call_no_zt0_state);
7552+
}
75327553
}
75337554

75347555
if (FDecl && FDecl->hasAttr<AllocAlignAttr>()) {

clang/lib/Sema/SemaDecl.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12179,12 +12179,15 @@ bool Sema::CheckFunctionDeclaration(Scope *S, FunctionDecl *NewFD,
1217912179
const auto *Attr = NewFD->getAttr<ArmNewAttr>();
1218012180
bool UsesSM = NewFD->hasAttr<ArmLocallyStreamingAttr>();
1218112181
bool UsesZA = Attr && Attr->isNewZA();
12182+
bool UsesZT0 = Attr && Attr->isNewZT0();
1218212183
if (const auto *FPT = NewFD->getType()->getAs<FunctionProtoType>()) {
1218312184
FunctionProtoType::ExtProtoInfo EPI = FPT->getExtProtoInfo();
1218412185
UsesSM |=
1218512186
EPI.AArch64SMEAttributes & FunctionType::SME_PStateSMEnabledMask;
1218612187
UsesZA |= FunctionType::getArmZAState(EPI.AArch64SMEAttributes) !=
1218712188
FunctionType::ARM_None;
12189+
UsesZT0 |= FunctionType::getArmZT0State(EPI.AArch64SMEAttributes) !=
12190+
FunctionType::ARM_None;
1218812191
}
1218912192

1219012193
if (UsesSM || UsesZA) {
@@ -12199,6 +12202,14 @@ bool Sema::CheckFunctionDeclaration(Scope *S, FunctionDecl *NewFD,
1219912202
diag::err_sme_definition_using_za_in_non_sme_target);
1220012203
}
1220112204
}
12205+
if (UsesZT0) {
12206+
llvm::StringMap<bool> FeatureMap;
12207+
Context.getFunctionFeatureMap(FeatureMap, NewFD);
12208+
if (!FeatureMap.contains("sme2")) {
12209+
Diag(NewFD->getLocation(),
12210+
diag::err_sme_definition_using_zt0_in_non_sme2_target);
12211+
}
12212+
}
1220212213
}
1220312214

1220412215
return Redeclaration;

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8843,6 +8843,7 @@ static void handleArmNewAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
88438843
}
88448844

88458845
bool HasZA = false;
8846+
bool HasZT0 = false;
88468847
for (unsigned I = 0, E = AL.getNumArgs(); I != E; ++I) {
88478848
StringRef StateName;
88488849
SourceLocation LiteralLoc;
@@ -8851,6 +8852,8 @@ static void handleArmNewAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
88518852

88528853
if (StateName == "za")
88538854
HasZA = true;
8855+
else if (StateName == "zt0")
8856+
HasZT0 = true;
88548857
else {
88558858
S.Diag(LiteralLoc, diag::err_unknown_arm_state) << StateName;
88568859
AL.setInvalid();
@@ -8869,6 +8872,11 @@ static void handleArmNewAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
88698872
if (HasZA && ZAState != FunctionType::ARM_None &&
88708873
checkArmNewAttrMutualExclusion(S, AL, FPT, ZAState, "za"))
88718874
return;
8875+
FunctionType::ArmStateValue ZT0State =
8876+
FunctionType::getArmZT0State(FPT->getAArch64SMEAttributes());
8877+
if (HasZT0 && ZT0State != FunctionType::ARM_None &&
8878+
checkArmNewAttrMutualExclusion(S, AL, FPT, ZT0State, "zt0"))
8879+
return;
88728880
}
88738881

88748882
D->dropAttr<ArmNewAttr>();

clang/lib/Sema/SemaType.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7899,6 +7899,9 @@ static bool handleArmStateAttribute(Sema &S,
78997899
if (StateName == "za") {
79007900
Shift = FunctionType::SME_ZAShift;
79017901
ExistingState = FunctionType::getArmZAState(EPI.AArch64SMEAttributes);
7902+
} else if (StateName == "zt0") {
7903+
Shift = FunctionType::SME_ZT0Shift;
7904+
ExistingState = FunctionType::getArmZT0State(EPI.AArch64SMEAttributes);
79027905
} else {
79037906
S.Diag(LiteralLoc, diag::err_unknown_arm_state) << StateName;
79047907
Attr.setInvalid();

0 commit comments

Comments
 (0)