Skip to content

[AArch64][SME2] Preserve ZT0 state around function calls #78321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2341,6 +2341,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::SMSTART)
MAKE_CASE(AArch64ISD::SMSTOP)
MAKE_CASE(AArch64ISD::RESTORE_ZA)
MAKE_CASE(AArch64ISD::RESTORE_ZT)
MAKE_CASE(AArch64ISD::SAVE_ZT)
MAKE_CASE(AArch64ISD::CALL)
MAKE_CASE(AArch64ISD::ADRP)
MAKE_CASE(AArch64ISD::ADR)
Expand Down Expand Up @@ -7654,6 +7656,34 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
});
}

SDValue ZTFrameIdx;
MachineFrameInfo &MFI = MF.getFrameInfo();
bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);

// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
if (ShouldPreserveZT0) {
unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
ZTFrameIdx = DAG.getFrameIndex(
ZTObj,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));

Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
{Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
}

// If caller shares ZT0 but the callee is not shared ZA, we need to stop
// PSTATE.ZA before the call if there is no lazy-save active.
bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
assert((!DisableZA || !RequiresLazySave) &&
"Lazy-save should have PSTATE.SM=1 on entry to the function");

if (DisableZA)
Chain = DAG.getNode(
AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));

// Adjust the stack pointer for the new arguments...
// These operations are automatically eliminated by the prolog/epilog pass
if (!IsSibCall)
Expand Down Expand Up @@ -8065,13 +8095,19 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Result, InGlue, PStateSM, false);
}

if (RequiresLazySave) {
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, MVT::Other, Result,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));

if (ShouldPreserveZT0)
Result =
DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
{Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});

if (RequiresLazySave) {
// Conditionally restore the lazy save using a pseudo node.
unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
SDValue RegMask = DAG.getRegisterMask(
Expand Down Expand Up @@ -8100,7 +8136,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getConstant(0, DL, MVT::i64));
}

if (RequiresSMChange || RequiresLazySave) {
if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
for (unsigned I = 0; I < InVals.size(); ++I) {
// The smstart/smstop is chained as part of the call, but when the
// resulting chain is discarded (which happens when the call is not part
Expand Down Expand Up @@ -23977,6 +24013,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getMergeValues(
{A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
}
case Intrinsic::aarch64_sme_ldr_zt:
return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
DAG.getVTList(MVT::Other), N->getOperand(0),
N->getOperand(2), N->getOperand(3));
case Intrinsic::aarch64_sme_str_zt:
return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
DAG.getVTList(MVT::Other), N->getOperand(0),
N->getOperand(2), N->getOperand(3));
default:
break;
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ enum NodeType : unsigned {
SMSTART,
SMSTOP,
RESTORE_ZA,
RESTORE_ZT,
SAVE_ZT,

// Produces the full sequence of instructions for getting the thread pointer
// offset of a variable into X0, using the TLSDesc model.
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
SDNPOptInGlue]>;
def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
[SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
[SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;

//===----------------------------------------------------------------------===//
// Instruction naming conventions.
Expand Down Expand Up @@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops

defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;

defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;

def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ class SMEAttrs {
State == StateValue::InOut || State == StateValue::Preserved;
}
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
bool requiresPreservingZT0(const SMEAttrs &Callee) const {
return hasZT0State() && !Callee.sharesZT0();
}
bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
}
bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
}
};

} // namespace llvm
Expand Down
155 changes: 155 additions & 0 deletions llvm/test/CodeGen/AArch64/sme-zt0-state.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s

declare void @callee();

;
; Private-ZA Callee
;

; Expect spill & fill of ZT0 around call
; Expect smstop/smstart za around call
define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_in_caller_no_state_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
call void @callee();
ret void;
}

; Expect spill & fill of ZT0 around call
; Expect setup and restore lazy-save around call
; Expect smstart za after call
define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT: mov x29, sp
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: sub x10, x29, #16
; CHECK-NEXT: sub x19, x29, #80
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur x9, [x29, #-16]
; CHECK-NEXT: sturh w8, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: sub x0, x29, #16
; CHECK-NEXT: cbnz x8, .LBB1_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: bl __arm_tpidr2_restore
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: mov sp, x29
; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @callee();
ret void;
}

;
; Shared-ZA Callee
;

; Caller and callee have shared ZT0 state, no spill/fill of ZT0 required
define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: bl callee
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
call void @callee() "aarch64_in_zt0";
ret void;
}

; Expect spill & fill of ZT0 around call
define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT: mov x29, sp
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x8, x8, x8, x9
; CHECK-NEXT: mov sp, x8
; CHECK-NEXT: sub x19, x29, #80
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur x8, [x29, #-16]
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: bl callee
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: mov sp, x29
; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @callee() "aarch64_pstate_za_shared";
ret void;
}

; Caller and callee have shared ZA & ZT0
define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
; CHECK-NEXT: mov x29, sp
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x8, x8, x8, x9
; CHECK-NEXT: mov sp, x8
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur x8, [x29, #-16]
; CHECK-NEXT: bl callee
; CHECK-NEXT: mov sp, x29
; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
ret void;
}

; New-ZA Callee

; Expect spill & fill of ZT0 around call
; Expect smstop/smstart za around call
define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_in_caller_zt0_new_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
call void @callee() "aarch64_new_zt0";
ret void;
}
36 changes: 36 additions & 0 deletions llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ TEST(SMEAttributes, Basics) {
TEST(SMEAttributes, Transitions) {
// Normal -> Normal
ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal)));
ASSERT_FALSE(SA(SA::Normal).requiresPreservingZT0(SA(SA::Normal)));
ASSERT_FALSE(SA(SA::Normal).requiresDisablingZABeforeCall(SA(SA::Normal)));
ASSERT_FALSE(SA(SA::Normal).requiresEnablingZAAfterCall(SA(SA::Normal)));
// Normal -> Normal + LocallyStreaming
ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body)));

Expand Down Expand Up @@ -240,4 +243,37 @@ TEST(SMEAttributes, Transitions) {
// Streaming-compatible -> Streaming-compatible + LocallyStreaming
ASSERT_FALSE(SA(SA::SM_Compatible)
.requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));

SA Private_ZA = SA(SA::Normal);
SA ZA_Shared = SA(SA::ZA_Shared);
SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
SA ZA_ZT0_Shared = SA(SA::ZA_Shared | SA::encodeZT0State(SA::StateValue::In));

// Shared ZA -> Private ZA Interface
ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
ASSERT_TRUE(ZA_Shared.requiresEnablingZAAfterCall(Private_ZA));

// Shared ZT0 -> Private ZA Interface
ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));

// Shared ZA & ZT0 -> Private ZA Interface
ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
ASSERT_TRUE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));

// Shared ZA -> Shared ZA Interface
ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
ASSERT_FALSE(ZA_Shared.requiresEnablingZAAfterCall(ZT0_Shared));

// Shared ZT0 -> Shared ZA Interface
ASSERT_FALSE(ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
ASSERT_FALSE(ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));

// Shared ZA & ZT0 -> Shared ZA Interface
ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
ASSERT_FALSE(ZA_ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
ASSERT_FALSE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
}