diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index e98f6c4984a75..f63cdf8bc4f32 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5176,7 +5176,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
                                   const TargetLibraryInfo *LibInfo) {
 
   SMEAttrs CallerAttrs(*FuncInfo.Fn);
-  if (CallerAttrs.hasZAState() || CallerAttrs.hasStreamingInterfaceOrBody() ||
+  if (CallerAttrs.hasZAState() || CallerAttrs.hasZTState() ||
+      CallerAttrs.hasStreamingInterfaceOrBody() ||
       CallerAttrs.hasStreamingCompatibleInterface())
     return nullptr;
   return new AArch64FastISel(FuncInfo, LibInfo);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 102fd0c3dae2a..4121621616b8b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2338,6 +2338,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
+    MAKE_CASE(AArch64ISD::RESTORE_ZT)
+    MAKE_CASE(AArch64ISD::SAVE_ZT)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -7659,6 +7661,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     });
   }
 
+  SDValue ZTFrameIdx;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  bool PreserveZT = CallerAttrs.requiresPreservingZT(CalleeAttrs);
+
+  if (PreserveZT) {
+    unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+    ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
   if (!IsSibCall)
@@ -8077,6 +8093,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
 
+    if (PreserveZT)
+      Result =
+          DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+                      {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+
     // Conditionally restore the lazy save using a pseudo node.
     unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
     SDValue RegMask = DAG.getRegisterMask(
@@ -8105,7 +8126,7 @@
                               DAG.getConstant(0, DL, MVT::i64));
   }
 
-  if (RequiresSMChange || RequiresLazySave) {
+  if (RequiresSMChange || RequiresLazySave || PreserveZT) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -23953,6 +23974,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
       return DAG.getMergeValues(
           {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
     }
+    case Intrinsic::aarch64_sme_ldr_zt:
+      return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
+    case Intrinsic::aarch64_sme_str_zt:
+      return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
     default:
       break;
     }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6ddbcd41dcb76..6c14bc0aa8dc7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -61,6 +61,8 @@ enum NodeType : unsigned {
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
+  RESTORE_ZT,
+  SAVE_ZT,
 
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 380f6e1fcfdae..eeae5303a3f89 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
                              [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue]>;
+def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
+                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                                [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
 
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
@@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops
 
 defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;
 
-defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
-defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
+defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
+defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;
 
 def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
 def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 3315171798d9f..4ca0cf648bc14 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -40,7 +40,8 @@ struct SMEABI : public FunctionPass {
   bool runOnFunction(Function &F) override;
 
 private:
-  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder);
+  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder,
+                            bool ClearZTState);
 };
 
 } // end anonymous namespace
@@ -82,8 +83,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
 /// is active and we should call __arm_tpidr2_save to commit the lazy save.
 /// Additionally, PSTATE.ZA should be enabled at the beginning of the function
 /// and disabled before returning.
-bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
-                                  IRBuilder<> &Builder) {
+bool SMEABI::updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder,
+                                  bool ClearZTState) {
   LLVMContext &Context = F->getContext();
   BasicBlock *OrigBB = &F->getEntryBlock();
 
@@ -117,6 +118,14 @@ bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
   Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
                      Builder.getInt32(0xff));
 
+  // Clear ZT0 on entry to the function if required, after enabling pstate.za
+  if (ClearZTState) {
+    Function *ClearZT0Intr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero_zt);
+    Builder.CreateCall(ClearZT0Intr->getFunctionType(), ClearZT0Intr,
+                       {Builder.getInt32(0)});
+  }
+
   // Before returning, disable pstate.za
   for (BasicBlock &BB : *F) {
     Instruction *T = BB.getTerminator();
@@ -143,7 +152,8 @@ bool SMEABI::runOnFunction(Function &F) {
   bool Changed = false;
   SMEAttrs FnAttrs(F);
   if (FnAttrs.hasNewZABody())
-    Changed |= updateNewZAFunctions(M, &F, Builder);
+    Changed |= updateNewZAFunctions(M, &F, Builder,
+                                    FnAttrs.requiresPreservingZT(SMEAttrs()));
 
   return Changed;
 }
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 0082b4017986c..ef3a043a15bcc 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -18,8 +18,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
   else
     Bitmask &= ~M;
 
+  // Streaming Mode Attrs
   assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) &&
          "SM_Enabled and SM_Compatible are mutually exclusive");
+  // ZA Attrs
   assert(!(hasNewZABody() && hasSharedZAInterface()) &&
          "ZA_New and ZA_Shared are mutually exclusive");
   assert(!(hasNewZABody() && preservesZA()) &&
@@ -28,6 +30,11 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "ZA_New and ZA_NoLazySave are mutually exclusive");
   assert(!(hasSharedZAInterface() && (Bitmask & ZA_NoLazySave)) &&
          "ZA_Shared and ZA_NoLazySave are mutually exclusive");
+  // ZT Attrs
+  assert(!(hasNewZTBody() && hasSharedZTInterface()) &&
+         "ZT_New and ZT_Shared are mutually exclusive");
+  assert(!(hasNewZTBody() && preservesZT()) &&
+         "ZT_New and ZT_Preserved are mutually exclusive");
 }
 
 SMEAttrs::SMEAttrs(const CallBase &CB) {
@@ -40,10 +47,10 @@
 SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
-                SMEAttrs::ZA_NoLazySave);
+                SMEAttrs::ZA_NoLazySave | SMEAttrs::ZT_Preserved);
   if (FuncName == "__arm_tpidr2_restore")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
-                SMEAttrs::ZA_NoLazySave);
+                SMEAttrs::ZA_NoLazySave | SMEAttrs::ZT_Shared);
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
@@ -60,6 +67,12 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= ZA_New;
   if (Attrs.hasFnAttr("aarch64_pstate_za_preserved"))
     Bitmask |= ZA_Preserved;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_shared"))
+    Bitmask |= ZT_Shared;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_new"))
+    Bitmask |= ZT_New;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_preserved"))
+    Bitmask |= ZT_Preserved;
 }
 
 std::optional<bool>
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index e766b778b5410..3eceaf95a249a 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -36,7 +36,10 @@ class SMEAttrs {
     ZA_New = 1 << 4,       // aarch64_pstate_sm_new
     ZA_Preserved = 1 << 5, // aarch64_pstate_sm_preserved
     ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
-    All = ZA_Preserved - 1
+    ZT_New = 1 << 7,       // aarch64_sme_pstate_zt0_new
+    ZT_Shared = 1 << 8,    // aarch64_sme_pstate_zt0_shared
+    ZT_Preserved = 1 << 9, // aarch64_sme_pstate_zt0_preserved
+    All = ZT_Preserved - 1
   };
 
   SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
@@ -74,6 +77,14 @@
   requiresSMChange(const SMEAttrs &Callee,
                    bool BodyOverridesInterface = false) const;
 
+  /// \return true if a call from Caller -> Callee requires ZT0 state to be
+  /// preserved.
+  /// ZT0 must be preserved if the caller has ZT state and the callee
+  /// does not preserve ZT.
+  bool requiresPreservingZT(const SMEAttrs &Callee) const {
+    return hasZTState() && !Callee.preservesZT();
+  }
+
   // Interfaces to query PSTATE.ZA
   bool hasNewZABody() const { return Bitmask & ZA_New; }
   bool hasSharedZAInterface() const { return Bitmask & ZA_Shared; }
@@ -82,6 +93,13 @@ class SMEAttrs {
   bool hasZAState() const {
     return hasNewZABody() || hasSharedZAInterface();
   }
+
+  // Interfaces to query ZT0 state
+  bool hasNewZTBody() const { return Bitmask & ZT_New; }
+  bool hasSharedZTInterface() const { return Bitmask & ZT_Shared; }
+  bool preservesZT() const { return Bitmask & ZT_Preserved; }
+  bool hasZTState() const { return hasNewZTBody() || hasSharedZTInterface(); }
+
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
            !(Callee.Bitmask & ZA_NoLazySave);
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll b/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll
new file mode 100644
index 0000000000000..bbcfd5cac197b
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll
@@ -0,0 +1,306 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+; Normal callee, no ZT state
+declare void @normal_callee();
+
+; Callees with ZT state
+declare void @za_shared_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared";
+declare void @za_new_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new";
+
+; Callee with preserved ZT state
+declare void @za_preserved_callee() "aarch64_pstate_za_preserved" "aarch64_sme_pstate_zt0_preserved";
+
+
+define void @za_zt_new_caller_normal_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_normal_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB0_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    sub x9, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB0_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB0_4:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee();
+  ret void;
+}
+
+define void @za_zt_new_caller_za_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_za_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #144
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    sub x9, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl za_new_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB1_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_4:
+; CHECK-NEXT:    sub x8, x29, #144
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_new_callee();
+  call void @za_shared_callee();
+  ret void;
+}
+
+define void @za_zt_shared_caller_normal_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
+; CHECK-LABEL: za_zt_shared_caller_normal_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB2_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB2_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee();
+  ret void;
+}
+
+define void @za_zt_shared_caller_za_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
+; CHECK-LABEL: za_zt_shared_caller_za_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #144
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl za_new_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB3_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB3_2:
+; CHECK-NEXT:    sub x8, x29, #144
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_new_callee();
+  call void @za_shared_callee();
+  ret void;
+}
+
+define void @za_zt_new_caller_za_preserved_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_za_preserved_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB4_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB4_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x8, x29, #16
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    sturh wzr, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-NEXT:    bl za_preserved_callee
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_preserved_callee();
+  ret void;
+}
+
+define void @za_zt_shared_caller_za_preserved_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
+; CHECK-LABEL: za_zt_shared_caller_za_preserved_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    sub x9, x29, #16
+; CHECK-NEXT:    stp x8, xzr, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    bl za_preserved_callee
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_preserved_callee();
+  ret void;
+}
+
+define void @za_zt_preserved_caller_za_callee() "aarch64_pstate_za_preserved" "aarch64_sme_pstate_zt0_preserved" nounwind {
+; CHECK-LABEL: za_zt_preserved_caller_za_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    bl za_new_callee
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee();
+  call void @za_new_callee();
+  call void @za_shared_callee();
+  ret void;
+}
+
+define void @za_zt_preserved_caller_za_zt_preserved_callee() "aarch64_pstate_za_preserved" "aarch64_sme_pstate_zt0_preserved" nounwind {
+; CHECK-LABEL: za_zt_preserved_caller_za_zt_preserved_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl za_preserved_callee
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_preserved_callee();
+  ret void;
+}
+
+define i32 @spill_fill_zt_load_start_chain(ptr %ptr) "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" {
+; CHECK-LABEL: spill_fill_zt_load_start_chain:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    .cfi_def_cfa w29, 32
+; CHECK-NEXT:    .cfi_offset w19, -16
+; CHECK-NEXT:    .cfi_offset w30, -24
+; CHECK-NEXT:    .cfi_offset w29, -32
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    sub x8, x29, #80
+; CHECK-NEXT:    ldr w19, [x0]
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    mov w0, w19
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %loadval = load i32, ptr %ptr
+  call void @za_shared_callee()
+  ret i32 %loadval
+}
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 7780c71bbc00e..653e209e9c03c 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -50,6 +50,20 @@ TEST(SMEAttributes, Constructors) {
                      ->getFunction("foo"))
                   .preservesZA());
 
+  ASSERT_TRUE(
+      SA(*parseIR("declare void @foo() \"aarch64_sme_pstate_zt0_shared\"")
+             ->getFunction("foo"))
+          .hasSharedZTInterface());
+
+  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_sme_pstate_zt0_new\"")
+                     ->getFunction("foo"))
+                  .hasNewZTBody());
+
+  ASSERT_TRUE(
+      SA(*parseIR("declare void @foo() \"aarch64_sme_pstate_zt0_preserved\"")
+             ->getFunction("foo"))
+          .preservesZT());
+
   // Invalid combinations.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
                      "SM_Enabled and SM_Compatible are mutually exclusive");
@@ -58,6 +72,11 @@
   EXPECT_DEBUG_DEATH(SA(SA::ZA_New | SA::ZA_Preserved),
                      "ZA_New and ZA_Preserved are mutually exclusive");
 
+  EXPECT_DEBUG_DEATH(SA(SA::ZT_New | SA::ZT_Shared),
+                     "ZT_New and ZT_Shared are mutually exclusive");
+  EXPECT_DEBUG_DEATH(SA(SA::ZT_New | SA::ZT_Preserved),
+                     "ZT_New and ZT_Preserved are mutually exclusive");
+
   // Test that the set() methods equally check validity.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled).set(SA::SM_Compatible),
                      "SM_Enabled and SM_Compatible are mutually exclusive");
 
@@ -95,6 +114,20 @@ TEST(SMEAttributes, Basics) {
   ASSERT_FALSE(SA(SA::Normal).hasNewZABody());
   ASSERT_FALSE(SA(SA::Normal).hasZAState());
   ASSERT_FALSE(SA(SA::Normal).preservesZA());
+
+  // Test ZT0 state interfaces
+  ASSERT_TRUE(SA(SA::ZT_Shared).hasSharedZTInterface());
+  ASSERT_TRUE(SA(SA::ZT_Shared).hasZTState());
+  ASSERT_FALSE(SA(SA::ZT_Shared).preservesZT());
+  ASSERT_TRUE(SA(SA::ZT_Shared | SA::ZT_Preserved).preservesZT());
+
+  ASSERT_TRUE(SA(SA::ZT_New).hasNewZTBody());
+  ASSERT_TRUE(SA(SA::ZT_New).hasZTState());
+  ASSERT_FALSE(SA(SA::ZT_New).preservesZT());
+
+  ASSERT_FALSE(SA(SA::Normal).hasNewZTBody());
+  ASSERT_FALSE(SA(SA::Normal).hasZTState());
+  ASSERT_FALSE(SA(SA::Normal).preservesZT());
 }
 
 TEST(SMEAttributes, Transitions) {