Skip to content

[Clang] Amend SME attributes with support for ZT0. #77941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -4056,10 +4056,12 @@ class FunctionType : public Type {
// Describes the value of the state using ArmStateValue.
SME_ZAShift = 2,
SME_ZAMask = 0b111 << SME_ZAShift,
SME_ZT0Shift = 5,
SME_ZT0Mask = 0b111 << SME_ZT0Shift,

SME_AttributeMask = 0b111'111 // We only support maximum 6 bits because of
// the bitmask in FunctionTypeArmAttributes
// and ExtProtoInfo.
SME_AttributeMask =
0b111'111'11 // We can't support more than 8 bits because of
// the bitmask in FunctionTypeExtraBitfields.
};

enum ArmStateValue : unsigned {
Expand All @@ -4074,13 +4076,17 @@ class FunctionType : public Type {
return (ArmStateValue)((AttrBits & SME_ZAMask) >> SME_ZAShift);
}

static ArmStateValue getArmZT0State(unsigned AttrBits) {
return (ArmStateValue)((AttrBits & SME_ZT0Mask) >> SME_ZT0Shift);
}

/// A holder for Arm type attributes as described in the Arm C/C++
/// Language extensions which are not particularly common to all
/// types and therefore accounted separately from FunctionTypeBitfields.
struct alignas(void *) FunctionTypeArmAttributes {
/// Any AArch64 SME ACLE type attributes that need to be propagated
/// on declarations and function pointers.
unsigned AArch64SMEAttributes : 6;
unsigned AArch64SMEAttributes : 8;

FunctionTypeArmAttributes() : AArch64SMEAttributes(SME_NormalFunction) {}
};
Expand Down Expand Up @@ -4266,7 +4272,7 @@ class FunctionProtoType final
FunctionType::ExtInfo ExtInfo;
unsigned Variadic : 1;
unsigned HasTrailingReturn : 1;
unsigned AArch64SMEAttributes : 6;
unsigned AArch64SMEAttributes : 8;
Qualifiers TypeQuals;
RefQualifierKind RefQualifier = RQ_None;
ExceptionSpecInfo ExceptionSpec;
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -2552,6 +2552,9 @@ def ArmNew : InheritableAttr, TargetSpecificAttr<TargetAArch64> {
bool isNewZA() const {
return llvm::is_contained(newArgs(), "za");
}
bool isNewZT0() const {
return llvm::is_contained(newArgs(), "zt0");
}
}];
}

Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -3706,10 +3706,14 @@ def err_sme_call_in_non_sme_target : Error<
"call to a streaming function requires 'sme'">;
def err_sme_za_call_no_za_state : Error<
"call to a shared ZA function requires the caller to have ZA state">;
def err_sme_zt0_call_no_zt0_state : Error<
"call to a shared ZT0 function requires the caller to have ZT0 state">;
def err_sme_definition_using_sm_in_non_sme_target : Error<
"function executed in streaming-SVE mode requires 'sme'">;
def err_sme_definition_using_za_in_non_sme_target : Error<
"function using ZA state requires 'sme'">;
def err_sme_definition_using_zt0_in_non_sme2_target : Error<
"function using ZT0 state requires 'sme2'">;
def err_conflicting_attributes_arm_state : Error<
"conflicting attributes for state '%0'">;
def err_unknown_arm_state : Error<
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,14 @@ void TypePrinter::printFunctionProtoAfter(const FunctionProtoType *T,
OS << " __arm_out(\"za\")";
if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_InOut)
OS << " __arm_inout(\"za\")";
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Preserves)
OS << " __arm_preserves(\"zt0\")";
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_In)
OS << " __arm_in(\"zt0\")";
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Out)
OS << " __arm_out(\"zt0\")";
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_InOut)
OS << " __arm_inout(\"zt0\")";

printFunctionAfter(Info, OS);

Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,16 @@ static void AddAttributesFromFunctionProtoType(ASTContext &Ctx,
FuncAttrs.addAttribute("aarch64_pstate_za_shared");
FuncAttrs.addAttribute("aarch64_pstate_za_preserved");
}

// ZT0
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Preserves)
FuncAttrs.addAttribute("aarch64_preserves_zt0");
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_In)
FuncAttrs.addAttribute("aarch64_in_zt0");
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Out)
FuncAttrs.addAttribute("aarch64_out_zt0");
if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_InOut)
FuncAttrs.addAttribute("aarch64_inout_zt0");
}

static void AddAttributesFromAssumes(llvm::AttrBuilder &FuncAttrs,
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2415,6 +2415,8 @@ void CodeGenModule::SetLLVMFunctionAttributesForDefinition(const Decl *D,
if (auto *Attr = D->getAttr<ArmNewAttr>()) {
if (Attr->isNewZA())
B.addAttribute("aarch64_pstate_za_new");
if (Attr->isNewZT0())
B.addAttribute("aarch64_new_zt0");
}

// Track whether we need to add the optnone LLVM attribute,
Expand Down
22 changes: 22 additions & 0 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7548,6 +7548,28 @@ void Sema::checkCall(NamedDecl *FDecl, const FunctionProtoType *Proto,
if (!CallerHasZAState)
Diag(Loc, diag::err_sme_za_call_no_za_state);
}

// If the callee uses AArch64 SME ZT0 state but the caller doesn't define
// any, then this is an error.
FunctionType::ArmStateValue ArmZT0State =
FunctionType::getArmZT0State(ExtInfo.AArch64SMEAttributes);
if (ArmZT0State != FunctionType::ARM_None) {
bool CallerHasZT0State = false;
if (const auto *CallerFD = dyn_cast<FunctionDecl>(CurContext)) {
auto *Attr = CallerFD->getAttr<ArmNewAttr>();
if (Attr && Attr->isNewZT0())
CallerHasZT0State = true;
else if (const auto *FPT =
CallerFD->getType()->getAs<FunctionProtoType>())
CallerHasZT0State =
FunctionType::getArmZT0State(
FPT->getExtProtoInfo().AArch64SMEAttributes) !=
FunctionType::ARM_None;
}

if (!CallerHasZT0State)
Diag(Loc, diag::err_sme_zt0_call_no_zt0_state);
}
}

if (FDecl && FDecl->hasAttr<AllocAlignAttr>()) {
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/Sema/SemaDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12234,12 +12234,15 @@ bool Sema::CheckFunctionDeclaration(Scope *S, FunctionDecl *NewFD,
const auto *Attr = NewFD->getAttr<ArmNewAttr>();
bool UsesSM = NewFD->hasAttr<ArmLocallyStreamingAttr>();
bool UsesZA = Attr && Attr->isNewZA();
bool UsesZT0 = Attr && Attr->isNewZT0();
if (const auto *FPT = NewFD->getType()->getAs<FunctionProtoType>()) {
FunctionProtoType::ExtProtoInfo EPI = FPT->getExtProtoInfo();
UsesSM |=
EPI.AArch64SMEAttributes & FunctionType::SME_PStateSMEnabledMask;
UsesZA |= FunctionType::getArmZAState(EPI.AArch64SMEAttributes) !=
FunctionType::ARM_None;
UsesZT0 |= FunctionType::getArmZT0State(EPI.AArch64SMEAttributes) !=
FunctionType::ARM_None;
}

if (UsesSM || UsesZA) {
Expand All @@ -12254,6 +12257,14 @@ bool Sema::CheckFunctionDeclaration(Scope *S, FunctionDecl *NewFD,
diag::err_sme_definition_using_za_in_non_sme_target);
}
}
if (UsesZT0) {
llvm::StringMap<bool> FeatureMap;
Context.getFunctionFeatureMap(FeatureMap, NewFD);
if (!FeatureMap.contains("sme2")) {
Diag(NewFD->getLocation(),
diag::err_sme_definition_using_zt0_in_non_sme2_target);
}
}
}

return Redeclaration;
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8994,6 +8994,7 @@ static void handleArmNewAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
}

bool HasZA = false;
bool HasZT0 = false;
for (unsigned I = 0, E = AL.getNumArgs(); I != E; ++I) {
StringRef StateName;
SourceLocation LiteralLoc;
Expand All @@ -9002,6 +9003,8 @@ static void handleArmNewAttr(Sema &S, Decl *D, const ParsedAttr &AL) {

if (StateName == "za")
HasZA = true;
else if (StateName == "zt0")
HasZT0 = true;
else {
S.Diag(LiteralLoc, diag::err_unknown_arm_state) << StateName;
AL.setInvalid();
Expand All @@ -9018,6 +9021,11 @@ static void handleArmNewAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
if (HasZA && ZAState != FunctionType::ARM_None &&
checkArmNewAttrMutualExclusion(S, AL, FPT, ZAState, "za"))
return;
FunctionType::ArmStateValue ZT0State =
FunctionType::getArmZT0State(FPT->getAArch64SMEAttributes());
if (HasZT0 && ZT0State != FunctionType::ARM_None &&
checkArmNewAttrMutualExclusion(S, AL, FPT, ZT0State, "zt0"))
return;
}

D->dropAttr<ArmNewAttr>();
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7938,6 +7938,9 @@ static bool handleArmStateAttribute(Sema &S,
if (StateName == "za") {
Shift = FunctionType::SME_ZAShift;
ExistingState = FunctionType::getArmZAState(EPI.AArch64SMEAttributes);
} else if (StateName == "zt0") {
Shift = FunctionType::SME_ZT0Shift;
ExistingState = FunctionType::getArmZT0State(EPI.AArch64SMEAttributes);
} else {
S.Diag(LiteralLoc, diag::err_unknown_arm_state) << StateName;
Attr.setInvalid();
Expand Down
57 changes: 57 additions & 0 deletions clang/test/CodeGen/aarch64-sme2-intrinsics/aarch64-sme2-attrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme2 \
// RUN: -S -disable-O0-optnone -Werror -emit-llvm -o - %s \
// RUN: | opt -S -passes=mem2reg \
// RUN: | opt -S -passes=inline \
// RUN: | FileCheck %s

// Test the attributes for ZT0 and their mappings to LLVM IR.

extern "C" {

// CHECK-LABEL: @in_zt0()
// CHECK-SAME: #[[ZT0_IN:[0-9]+]]
void in_zt0() __arm_in("zt0") { }

// CHECK-LABEL: @out_zt0()
// CHECK-SAME: #[[ZT0_OUT:[0-9]+]]
void out_zt0() __arm_out("zt0") { }

// CHECK-LABEL: @inout_zt0()
// CHECK-SAME: #[[ZT0_INOUT:[0-9]+]]
void inout_zt0() __arm_inout("zt0") { }

// CHECK-LABEL: @preserves_zt0()
// CHECK-SAME: #[[ZT0_PRESERVED:[0-9]+]]
void preserves_zt0() __arm_preserves("zt0") { }

// CHECK-LABEL: @new_zt0()
// CHECK-SAME: #[[ZT0_NEW:[0-9]+]]
__arm_new("zt0") void new_zt0() { }

// CHECK-LABEL: @in_za_zt0()
// CHECK-SAME: #[[ZA_ZT0_IN:[0-9]+]]
void in_za_zt0() __arm_in("za", "zt0") { }

// CHECK-LABEL: @out_za_zt0()
// CHECK-SAME: #[[ZA_ZT0_OUT:[0-9]+]]
void out_za_zt0() __arm_out("za", "zt0") { }

// CHECK-LABEL: @inout_za_zt0()
// CHECK-SAME: #[[ZA_ZT0_INOUT:[0-9]+]]
void inout_za_zt0() __arm_inout("za", "zt0") { }

// CHECK-LABEL: @preserves_za_zt0()
// CHECK-SAME: #[[ZA_ZT0_PRESERVED:[0-9]+]]
void preserves_za_zt0() __arm_preserves("za", "zt0") { }

// CHECK-LABEL: @new_za_zt0()
// CHECK-SAME: #[[ZA_ZT0_NEW:[0-9]+]]
__arm_new("za", "zt0") void new_za_zt0() { }

}

// CHECK: attributes #[[ZT0_IN]] = {{{.*}} "aarch64_in_zt0" {{.*}}}
// CHECK: attributes #[[ZT0_OUT]] = {{{.*}} "aarch64_out_zt0" {{.*}}}
// CHECK: attributes #[[ZT0_INOUT]] = {{{.*}} "aarch64_inout_zt0" {{.*}}}
// CHECK: attributes #[[ZT0_PRESERVED]] = {{{.*}} "aarch64_preserves_zt0" {{.*}}}
// CHECK: attributes #[[ZT0_NEW]] = {{{.*}} "aarch64_new_zt0" {{.*}}}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ void shared_za_def() __arm_inout("za") { } // expected-error {{function using ZA
__arm_new("za") void new_za_def() { } // expected-error {{function using ZA state requires 'sme'}}
__arm_locally_streaming void locally_streaming_def() { } // expected-error {{function executed in streaming-SVE mode requires 'sme'}}
void streaming_shared_za_def() __arm_streaming __arm_inout("za") { } // expected-error {{function executed in streaming-SVE mode requires 'sme'}}
void inout_za_def() __arm_inout("za") { } // expected-error {{function using ZA state requires 'sme'}}
void inout_zt0_def() __arm_inout("zt0") { } // expected-error {{function using ZT0 state requires 'sme2'}}

// It should work fine when we explicitly add the target("sme") attribute.
__attribute__((target("sme"))) void streaming_compatible_def_sme_attr() __arm_streaming_compatible {} // OK
Expand Down
45 changes: 45 additions & 0 deletions clang/test/Sema/aarch64-sme-func-attrs.c
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ void invalid_arm_in_unknown_state(void) __arm_in("unknownstate");

void valid_state_attrs_in_in1(void) __arm_in("za");
void valid_state_attrs_in_in2(void) __arm_in("za", "za");
void valid_state_attrs_in_in3(void) __arm_in("zt0");
void valid_state_attrs_in_in4(void) __arm_in("zt0", "zt0");
void valid_state_attrs_in_in5(void) __arm_in("za", "zt0");
__arm_new("za") void valid_state_attrs_in_in6(void) __arm_in("zt0");
__arm_new("zt0") void valid_state_attrs_in_in7(void) __arm_in("za");

// expected-cpp-error@+2 {{missing state for '__arm_in'}}
// expected-error@+1 {{missing state for '__arm_in'}}
Expand Down Expand Up @@ -400,3 +405,43 @@ void conflicting_state_attrs_preserves_out(void) __arm_preserves("za") __arm_out
// expected-cpp-error@+2 {{conflicting attributes for state 'za'}}
// expected-error@+1 {{conflicting attributes for state 'za'}}
void conflicting_state_attrs_preserves_inout(void) __arm_preserves("za") __arm_inout("za");

// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_in_out_zt0(void) __arm_in("zt0") __arm_out("zt0");
// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_in_inout_zt0(void) __arm_in("zt0") __arm_inout("zt0");
// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_in_preserves_zt0(void) __arm_in("zt0") __arm_preserves("zt0");

// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_out_in_zt0(void) __arm_out("zt0") __arm_in("zt0");
// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_out_inout_zt0(void) __arm_out("zt0") __arm_inout("zt0");
// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_out_preserves_zt0(void) __arm_out("zt0") __arm_preserves("zt0");

// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_inout_in_zt0(void) __arm_inout("zt0") __arm_in("zt0");
// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_inout_out_zt0(void) __arm_inout("zt0") __arm_out("zt0");
// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_inout_preserves_zt0(void) __arm_inout("zt0") __arm_preserves("zt0");

// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_preserves_in_zt0(void) __arm_preserves("zt0") __arm_in("zt0");
// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_preserves_out_zt0(void) __arm_preserves("zt0") __arm_out("zt0");
// expected-cpp-error@+2 {{conflicting attributes for state 'zt0'}}
// expected-error@+1 {{conflicting attributes for state 'zt0'}}
void conflicting_state_attrs_preserves_inout_zt0(void) __arm_preserves("zt0") __arm_inout("zt0");