Skip to content

[AArch64][PAC] Lower auth/resign into checked sequence. #79024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@

using namespace llvm;

enum PtrauthCheckMode { Default, Unchecked, Poison, Trap };
static cl::opt<PtrauthCheckMode> PtrauthAuthChecks(
"aarch64-ptrauth-auth-checks", cl::Hidden,
cl::values(clEnumValN(Unchecked, "none", "don't test for failure"),
clEnumValN(Poison, "poison", "poison on failure"),
clEnumValN(Trap, "trap", "trap on failure")),
cl::desc("Check pointer authentication auth/resign failures"),
cl::init(Default));

#define DEBUG_TYPE "asm-printer"

namespace {
Expand Down Expand Up @@ -130,6 +139,10 @@ class AArch64AsmPrinter : public AsmPrinter {

// Emit the sequence for BLRA (authenticate + branch).
void emitPtrauthBranch(const MachineInstr *MI);

// Emit the sequence for AUT or AUTPAC.
void emitPtrauthAuthResign(const MachineInstr *MI);

// Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc,
unsigned &InstsEmitted);
Expand Down Expand Up @@ -1623,6 +1636,222 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
return AArch64::X17;
}

void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
unsigned InstsEmitted = 0;
const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;

// We can expand AUT/AUTPAC into 3 possible sequences:
// - unchecked:
// autia x16, x0
// pacib x16, x1 ; if AUTPAC
//
// - checked and clearing:
// mov x17, x0
// movk x17, #disc, lsl #48
// autia x16, x17
// mov x17, x16
// xpaci x17
// cmp x16, x17
// b.eq Lsuccess
// mov x16, x17
// b Lend
// Lsuccess:
// mov x17, x1
// movk x17, #disc, lsl #48
// pacib x16, x17
// Lend:
// Where we only emit the AUT if we started with an AUT.
//
// - checked and trapping:
// mov x17, x0
// movk x17, #disc, lsl #48
// autia x16, x0
// mov x17, x16
// xpaci x17
// cmp x16, x17
// b.eq Lsuccess
// brk #<0xc470 + aut key>
// Lsuccess:
// mov x17, x1
// movk x17, #disc, lsl #48
// pacib x16, x17 ; if AUTPAC
// Where the b.eq skips over the trap if the PAC is valid.
//
// This sequence is expensive, but we need more information to be able to
// do better.
//
// We can't TBZ the poison bit because EnhancedPAC2 XORs the PAC bits
// on failure.
// We can't TST the PAC bits because we don't always know how the address
// space is setup for the target environment (and the bottom PAC bit is
// based on that).
// Either way, we also don't always know whether TBI is enabled or not for
// the specific target environment.

// By default, auth/resign sequences check for auth failures.
bool ShouldCheck = true;
// In the checked sequence, we only trap if explicitly requested.
bool ShouldTrap = MF->getFunction().hasFnAttribute("ptrauth-auth-traps");

// On an FPAC CPU, you get traps whether you want them or not: there's
// no point in emitting checks or traps.
if (STI->hasFPAC())
ShouldCheck = ShouldTrap = false;

// However, command-line flags can override this, for experimentation.
switch (PtrauthAuthChecks) {
case PtrauthCheckMode::Default:
break;
case PtrauthCheckMode::Unchecked:
ShouldCheck = ShouldTrap = false;
break;
case PtrauthCheckMode::Poison:
ShouldCheck = true;
ShouldTrap = false;
break;
case PtrauthCheckMode::Trap:
ShouldCheck = ShouldTrap = true;
break;
}

auto AUTKey = (AArch64PACKey::ID)MI->getOperand(0).getImm();
uint64_t AUTDisc = MI->getOperand(1).getImm();
unsigned AUTAddrDisc = MI->getOperand(2).getReg();

unsigned XPACOpc = getXPACOpcodeForKey(AUTKey);

// Compute aut discriminator into x17
assert(isUInt<16>(AUTDisc));
unsigned AUTDiscReg =
emitPtrauthDiscriminator(AUTDisc, AUTAddrDisc, InstsEmitted);
bool AUTZero = AUTDiscReg == AArch64::XZR;
unsigned AUTOpc = getAUTOpcodeForKey(AUTKey, AUTZero);

// autiza x16 ; if AUTZero
// autia x16, x17 ; if !AUTZero
MCInst AUTInst;
AUTInst.setOpcode(AUTOpc);
AUTInst.addOperand(MCOperand::createReg(AArch64::X16));
AUTInst.addOperand(MCOperand::createReg(AArch64::X16));
if (!AUTZero)
AUTInst.addOperand(MCOperand::createReg(AUTDiscReg));
EmitToStreamer(*OutStreamer, AUTInst);
++InstsEmitted;

// Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done.
if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap)) {
assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
return;
}

MCSymbol *EndSym = nullptr;

// Checked sequences do an additional strip-and-compare.
if (ShouldCheck) {
MCSymbol *SuccessSym = createTempSymbol("auth_success_");

// XPAC has tied src/dst: use x17 as a temporary copy.
// mov x17, x16
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
.addReg(AArch64::X17)
.addReg(AArch64::XZR)
.addReg(AArch64::X16)
.addImm(0));
++InstsEmitted;

// xpaci x17
EmitToStreamer(
*OutStreamer,
MCInstBuilder(XPACOpc).addReg(AArch64::X17).addReg(AArch64::X17));
++InstsEmitted;

// cmp x16, x17
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
.addReg(AArch64::XZR)
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addImm(0));
++InstsEmitted;

// b.eq Lsuccess
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::Bcc)
.addImm(AArch64CC::EQ)
.addExpr(MCSymbolRefExpr::create(
SuccessSym, OutContext)));
++InstsEmitted;

if (ShouldTrap) {
// Trapping sequences do a 'brk'.
// brk #<0xc470 + aut key>
EmitToStreamer(*OutStreamer,
MCInstBuilder(AArch64::BRK).addImm(0xc470 | AUTKey));
++InstsEmitted;
} else {
// Non-trapping checked sequences return the stripped result in x16,
// skipping over the PAC if there is one.

// FIXME: can we simply return the AUT result, already in x16? without..
// ..traps this is usable as an oracle anyway, based on high bits
// mov x17, x16
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
.addReg(AArch64::X16)
.addReg(AArch64::XZR)
.addReg(AArch64::X17)
.addImm(0));
++InstsEmitted;

if (IsAUTPAC) {
EndSym = createTempSymbol("resign_end_");

// b Lend
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::B)
.addExpr(MCSymbolRefExpr::create(
EndSym, OutContext)));
++InstsEmitted;
}
}

// If the auth check succeeds, we can continue.
// Lsuccess:
OutStreamer->emitLabel(SuccessSym);
}

// We already emitted unchecked and checked-but-non-trapping AUTs.
// That left us with trapping AUTs, and AUTPACs.
// Trapping AUTs don't need PAC: we're done.
if (!IsAUTPAC) {
assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
return;
}

auto PACKey = (AArch64PACKey::ID)MI->getOperand(3).getImm();
uint64_t PACDisc = MI->getOperand(4).getImm();
unsigned PACAddrDisc = MI->getOperand(5).getReg();

// Compute pac discriminator into x17
assert(isUInt<16>(PACDisc));
unsigned PACDiscReg =
emitPtrauthDiscriminator(PACDisc, PACAddrDisc, InstsEmitted);
bool PACZero = PACDiscReg == AArch64::XZR;
unsigned PACOpc = getPACOpcodeForKey(PACKey, PACZero);

// pacizb x16 ; if PACZero
// pacib x16, x17 ; if !PACZero
MCInst PACInst;
PACInst.setOpcode(PACOpc);
PACInst.addOperand(MCOperand::createReg(AArch64::X16));
PACInst.addOperand(MCOperand::createReg(AArch64::X16));
if (!PACZero)
PACInst.addOperand(MCOperand::createReg(PACDiscReg));
EmitToStreamer(*OutStreamer, PACInst);
++InstsEmitted;

assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
// Lend:
if (EndSym)
OutStreamer->emitLabel(EndSym);
}

void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
unsigned InstsEmitted = 0;
unsigned BrTarget = MI->getOperand(0).getReg();
Expand Down Expand Up @@ -2056,6 +2285,11 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
return;
}

case AArch64::AUT:
case AArch64::AUTPAC:
emitPtrauthAuthResign(MI);
return;

case AArch64::LOADauthptrstatic:
LowerLOADauthptrstatic(*MI);
return;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64Features.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def FeatureJS : ExtensionWithMArch<"jsconv", "JS", "FEAT_JSCVT",
"Enable Armv8.3-A JavaScript FP conversion instructions",
[FeatureFPARMv8]>;

def FeatureFPAC : Extension<"fpac", "FPAC", "FEAT_FPAC",
"Enable v8.3-A Pointer Authentication Faulting enhancement">;

def FeatureCCIDX : Extension<"ccidx", "CCIDX", "FEAT_CCIDX",
"Enable Armv8.3-A Extend of the CCSIDR number of sets">;

Expand Down
102 changes: 102 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {

bool tryIndexedLoad(SDNode *N);

void SelectPtrauthAuth(SDNode *N);
void SelectPtrauthResign(SDNode *N);

bool trySelectStackSlotTagP(SDNode *N);
void SelectTagP(SDNode *N);

Expand Down Expand Up @@ -1481,6 +1484,96 @@ void AArch64DAGToDAGISel::SelectTable(SDNode *N, unsigned NumVecs, unsigned Opc,
ReplaceNode(N, CurDAG->getMachineNode(Opc, dl, VT, Ops));
}

static std::tuple<SDValue, SDValue>
extractPtrauthBlendDiscriminators(SDValue Disc, SelectionDAG *DAG) {
SDLoc DL(Disc);
SDValue AddrDisc;
SDValue ConstDisc;

// If this is a blend, remember the constant and address discriminators.
// Otherwise, it's either a constant discriminator, or a non-blended
// address discriminator.
if (Disc->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
Disc->getConstantOperandVal(0) == Intrinsic::ptrauth_blend) {
AddrDisc = Disc->getOperand(1);
ConstDisc = Disc->getOperand(2);
} else {
ConstDisc = Disc;
}

// If the constant discriminator (either the blend RHS, or the entire
// discriminator value) isn't a 16-bit constant, bail out, and let the
// discriminator be computed separately.
auto *ConstDiscN = dyn_cast<ConstantSDNode>(ConstDisc);
if (!ConstDiscN || !isUInt<16>(ConstDiscN->getZExtValue()))
return std::make_tuple(DAG->getTargetConstant(0, DL, MVT::i64), Disc);

// If there's no address discriminator, use XZR directly.
if (!AddrDisc)
AddrDisc = DAG->getRegister(AArch64::XZR, MVT::i64);

return std::make_tuple(
DAG->getTargetConstant(ConstDiscN->getZExtValue(), DL, MVT::i64),
AddrDisc);
}

void AArch64DAGToDAGISel::SelectPtrauthAuth(SDNode *N) {
SDLoc DL(N);
// IntrinsicID is operand #0
SDValue Val = N->getOperand(1);
SDValue AUTKey = N->getOperand(2);
SDValue AUTDisc = N->getOperand(3);

unsigned AUTKeyC = cast<ConstantSDNode>(AUTKey)->getZExtValue();
AUTKey = CurDAG->getTargetConstant(AUTKeyC, DL, MVT::i64);

SDValue AUTAddrDisc, AUTConstDisc;
std::tie(AUTConstDisc, AUTAddrDisc) =
extractPtrauthBlendDiscriminators(AUTDisc, CurDAG);

SDValue X16Copy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
AArch64::X16, Val, SDValue());
SDValue Ops[] = {AUTKey, AUTConstDisc, AUTAddrDisc, X16Copy.getValue(1)};

SDNode *AUT = CurDAG->getMachineNode(AArch64::AUT, DL, MVT::i64, Ops);
ReplaceNode(N, AUT);
return;
}

void AArch64DAGToDAGISel::SelectPtrauthResign(SDNode *N) {
SDLoc DL(N);
// IntrinsicID is operand #0
SDValue Val = N->getOperand(1);
SDValue AUTKey = N->getOperand(2);
SDValue AUTDisc = N->getOperand(3);
SDValue PACKey = N->getOperand(4);
SDValue PACDisc = N->getOperand(5);

unsigned AUTKeyC = cast<ConstantSDNode>(AUTKey)->getZExtValue();
unsigned PACKeyC = cast<ConstantSDNode>(PACKey)->getZExtValue();

AUTKey = CurDAG->getTargetConstant(AUTKeyC, DL, MVT::i64);
PACKey = CurDAG->getTargetConstant(PACKeyC, DL, MVT::i64);

SDValue AUTAddrDisc, AUTConstDisc;
std::tie(AUTConstDisc, AUTAddrDisc) =
extractPtrauthBlendDiscriminators(AUTDisc, CurDAG);

SDValue PACAddrDisc, PACConstDisc;
std::tie(PACConstDisc, PACAddrDisc) =
extractPtrauthBlendDiscriminators(PACDisc, CurDAG);

SDValue X16Copy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
AArch64::X16, Val, SDValue());

SDValue Ops[] = {AUTKey, AUTConstDisc, AUTAddrDisc, PACKey,
PACConstDisc, PACAddrDisc, X16Copy.getValue(1)};

SDNode *AUTPAC = CurDAG->getMachineNode(AArch64::AUTPAC, DL, MVT::i64, Ops);
ReplaceNode(N, AUTPAC);
return;
}

bool AArch64DAGToDAGISel::tryIndexedLoad(SDNode *N) {
LoadSDNode *LD = cast<LoadSDNode>(N);
if (LD->isUnindexed())
Expand Down Expand Up @@ -5437,6 +5530,15 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
case Intrinsic::aarch64_tagp:
SelectTagP(Node);
return;

case Intrinsic::ptrauth_auth:
SelectPtrauthAuth(Node);
return;

case Intrinsic::ptrauth_resign:
SelectPtrauthResign(Node);
return;

case Intrinsic::aarch64_neon_tbl2:
SelectTable(Node, 2,
VT == MVT::v8i8 ? AArch64::TBLv8i8Two : AArch64::TBLv16i8Two,
Expand Down
Loading
Loading