Skip to content

[AArch64][SME2] Preserve ZT0 state around function calls #76968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64FastISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5176,7 +5176,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
const TargetLibraryInfo *LibInfo) {

SMEAttrs CallerAttrs(*FuncInfo.Fn);
if (CallerAttrs.hasZAState() || CallerAttrs.hasStreamingInterfaceOrBody() ||
if (CallerAttrs.hasZAState() || CallerAttrs.hasZTState() ||
CallerAttrs.hasStreamingInterfaceOrBody() ||
CallerAttrs.hasStreamingCompatibleInterface())
return nullptr;
return new AArch64FastISel(FuncInfo, LibInfo);
Expand Down
31 changes: 30 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2338,6 +2338,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::SMSTART)
MAKE_CASE(AArch64ISD::SMSTOP)
MAKE_CASE(AArch64ISD::RESTORE_ZA)
MAKE_CASE(AArch64ISD::RESTORE_ZT)
MAKE_CASE(AArch64ISD::SAVE_ZT)
MAKE_CASE(AArch64ISD::CALL)
MAKE_CASE(AArch64ISD::ADRP)
MAKE_CASE(AArch64ISD::ADR)
Expand Down Expand Up @@ -7659,6 +7661,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
});
}

SDValue ZTFrameIdx;
MachineFrameInfo &MFI = MF.getFrameInfo();
bool PreserveZT = CallerAttrs.requiresPreservingZT(CalleeAttrs);

if (PreserveZT) {
unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
ZTFrameIdx = DAG.getFrameIndex(
ZTObj,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));

Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
{Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
}

// Adjust the stack pointer for the new arguments...
// These operations are automatically eliminated by the prolog/epilog pass
if (!IsSibCall)
Expand Down Expand Up @@ -8077,6 +8093,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));

if (PreserveZT)
Result =
DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
{Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});

// Conditionally restore the lazy save using a pseudo node.
unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
SDValue RegMask = DAG.getRegisterMask(
Expand Down Expand Up @@ -8105,7 +8126,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getConstant(0, DL, MVT::i64));
}

if (RequiresSMChange || RequiresLazySave) {
if (RequiresSMChange || RequiresLazySave || PreserveZT) {
for (unsigned I = 0; I < InVals.size(); ++I) {
// The smstart/smstop is chained as part of the call, but when the
// resulting chain is discarded (which happens when the call is not part
Expand Down Expand Up @@ -23953,6 +23974,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getMergeValues(
{A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
}
case Intrinsic::aarch64_sme_ldr_zt:
return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
DAG.getVTList(MVT::Other), N->getOperand(0),
N->getOperand(2), N->getOperand(3));
case Intrinsic::aarch64_sme_str_zt:
return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
DAG.getVTList(MVT::Other), N->getOperand(0),
N->getOperand(2), N->getOperand(3));
default:
break;
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ enum NodeType : unsigned {
SMSTART,
SMSTOP,
RESTORE_ZA,
RESTORE_ZT,
SAVE_ZT,

// Produces the full sequence of instructions for getting the thread pointer
// offset of a variable into X0, using the TLSDesc model.
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
SDNPOptInGlue]>;
// SelectionDAG nodes used to spill/fill ZT0 around calls.
// Profile: no results, two operands — an integer immediate (SDTCisInt<0>,
// the ZT register number) and a pointer (SDTCisPtrTy<1>, the spill address).
// Both carry a chain and are marked SDNPSideEffect so they cannot be
// reordered or dropped; RESTORE_ZT may load, SAVE_ZT may store.
def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                                [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;

//===----------------------------------------------------------------------===//
// Instruction naming conventions.
Expand Down Expand Up @@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops

defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;

defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
// ZT0 spill/fill instructions. These now select from the
// AArch64ISD::RESTORE_ZT/SAVE_ZT nodes rather than directly from the
// aarch64_sme_ldr_zt/str_zt intrinsics; the intrinsics are rewritten to
// these nodes during ISel so that compiler-generated ZT0 saves/restores
// (around calls) and user intrinsics share one selection path.
defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;

// Moves between ZT0 and a scalar register (see sme2_movt_zt_to_scalar /
// sme2_movt_scalar_to_zt).
def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;
Expand Down
18 changes: 14 additions & 4 deletions llvm/lib/Target/AArch64/SMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ struct SMEABI : public FunctionPass {
bool runOnFunction(Function &F) override;

private:
bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder);
bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder,
bool ClearZTState);
};
} // end anonymous namespace

Expand Down Expand Up @@ -82,8 +83,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
/// is active and we should call __arm_tpidr2_save to commit the lazy save.
/// Additionally, PSTATE.ZA should be enabled at the beginning of the function
/// and disabled before returning.
bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
IRBuilder<> &Builder) {
bool SMEABI::updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder,
bool ClearZTState) {
LLVMContext &Context = F->getContext();
BasicBlock *OrigBB = &F->getEntryBlock();

Expand Down Expand Up @@ -117,6 +118,14 @@ bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
Builder.getInt32(0xff));

// Clear ZT0 on entry to the function if required, after enabling pstate.za
if (ClearZTState) {
Function *ClearZT0Intr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero_zt);
Builder.CreateCall(ClearZT0Intr->getFunctionType(), ClearZT0Intr,
{Builder.getInt32(0)});
}

// Before returning, disable pstate.za
for (BasicBlock &BB : *F) {
Instruction *T = BB.getTerminator();
Expand All @@ -143,7 +152,8 @@ bool SMEABI::runOnFunction(Function &F) {
bool Changed = false;
SMEAttrs FnAttrs(F);
if (FnAttrs.hasNewZABody())
Changed |= updateNewZAFunctions(M, &F, Builder);
Changed |= updateNewZAFunctions(M, &F, Builder,
FnAttrs.requiresPreservingZT(SMEAttrs()));

return Changed;
}
17 changes: 15 additions & 2 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
else
Bitmask &= ~M;

// Streaming Mode Attrs
assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) &&
"SM_Enabled and SM_Compatible are mutually exclusive");
// ZA Attrs
assert(!(hasNewZABody() && hasSharedZAInterface()) &&
"ZA_New and ZA_Shared are mutually exclusive");
assert(!(hasNewZABody() && preservesZA()) &&
Expand All @@ -28,6 +30,11 @@ void SMEAttrs::set(unsigned M, bool Enable) {
"ZA_New and ZA_NoLazySave are mutually exclusive");
assert(!(hasSharedZAInterface() && (Bitmask & ZA_NoLazySave)) &&
"ZA_Shared and ZA_NoLazySave are mutually exclusive");
// ZT Attrs
assert(!(hasNewZTBody() && hasSharedZTInterface()) &&
"ZT_New and ZT_Shared are mutually exclusive");
assert(!(hasNewZTBody() && preservesZT()) &&
"ZT_New and ZT_Preserved are mutually exclusive");
}

SMEAttrs::SMEAttrs(const CallBase &CB) {
Expand All @@ -40,10 +47,10 @@ SMEAttrs::SMEAttrs(const CallBase &CB) {
/// Build the SME attributes for a function known only by name. The SME ABI
/// support routines are recognized by name and given their documented
/// attributes; any other name yields an empty attribute set.
SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
  // __arm_tpidr2_save / __arm_sme_state: streaming-compatible, preserve ZA
  // and ZT0, and must not trigger a lazy save themselves.
  if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state") {
    Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
               SMEAttrs::ZA_NoLazySave | SMEAttrs::ZT_Preserved;
  } else if (FuncName == "__arm_tpidr2_restore") {
    // __arm_tpidr2_restore: streaming-compatible, shares ZA and ZT0, and
    // must not trigger a lazy save itself.
    Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
               SMEAttrs::ZA_NoLazySave | SMEAttrs::ZT_Shared;
  }
}

SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Expand All @@ -60,6 +67,12 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= ZA_New;
if (Attrs.hasFnAttr("aarch64_pstate_za_preserved"))
Bitmask |= ZA_Preserved;
if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_shared"))
Bitmask |= ZT_Shared;
if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_new"))
Bitmask |= ZT_New;
if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_preserved"))
Bitmask |= ZT_Preserved;
}

std::optional<bool>
Expand Down
20 changes: 19 additions & 1 deletion llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ class SMEAttrs {
ZA_New = 1 << 4, // aarch64_pstate_sm_new
ZA_Preserved = 1 << 5, // aarch64_pstate_sm_preserved
ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
All = ZA_Preserved - 1
ZT_New = 1 << 7, // aarch64_sme_pstate_zt0_new
ZT_Shared = 1 << 8, // aarch64_sme_pstate_zt0_shared
ZT_Preserved = 1 << 9, // aarch64_sme_pstate_zt0_preserved
All = ZT_Preserved - 1
};

SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
Expand Down Expand Up @@ -74,6 +77,14 @@ class SMEAttrs {
requiresSMChange(const SMEAttrs &Callee,
bool BodyOverridesInterface = false) const;

/// \return true if a call from Caller -> Callee requires ZT0 state to be
/// preserved around the call.
/// The caller must arrange to preserve ZT0 when it has ZT0 state of its own
/// (new or shared) and the callee does not guarantee to leave ZT0 intact.
bool requiresPreservingZT(const SMEAttrs &Callee) const {
  // A caller with no ZT0 state has nothing to preserve.
  if (!hasZTState())
    return false;
  return !Callee.preservesZT();
}

// Interfaces to query PSTATE.ZA attributes.
// True when the function creates fresh ZA state (ZA_New bit set).
bool hasNewZABody() const { return Bitmask & ZA_New; }
// True when the function shares ZA state with its caller (ZA_Shared bit set).
bool hasSharedZAInterface() const { return Bitmask & ZA_Shared; }
Expand All @@ -82,6 +93,13 @@ class SMEAttrs {
// A function carries ZA state when it either creates new ZA state or
// shares its caller's.
bool hasZAState() const { return hasNewZABody() || hasSharedZAInterface(); }

// Interfaces to query ZT0 state.
// Set from the "aarch64_sme_pstate_zt0_new" function attribute.
bool hasNewZTBody() const { return Bitmask & ZT_New; }
// Set from the "aarch64_sme_pstate_zt0_shared" function attribute.
bool hasSharedZTInterface() const { return Bitmask & ZT_Shared; }
// Set from the "aarch64_sme_pstate_zt0_preserved" function attribute.
bool preservesZT() const { return Bitmask & ZT_Preserved; }
// A function carries ZT0 state when it is either new or shared.
bool hasZTState() const { return hasNewZTBody() || hasSharedZTInterface(); }

bool requiresLazySave(const SMEAttrs &Callee) const {
return hasZAState() && Callee.hasPrivateZAInterface() &&
!(Callee.Bitmask & ZA_NoLazySave);
Expand Down
Loading