Skip to content

[RISCV] Initial ISel support for the experimental zacas extension #67918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions llvm/lib/Target/RISCV/RISCVExpandAtomicPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class RISCVExpandAtomicPseudo : public MachineFunctionPass {
bool expandAtomicCmpXchg(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, bool IsMasked,
int Width, MachineBasicBlock::iterator &NextMBBI);
bool expandAMOCAS(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
bool IsPaired, int Width,
MachineBasicBlock::iterator &NextMBBI);
#ifndef NDEBUG
unsigned getInstSizeInBytes(const MachineFunction &MF) const {
unsigned Size = 0;
Expand Down Expand Up @@ -145,6 +148,14 @@ bool RISCVExpandAtomicPseudo::expandMI(MachineBasicBlock &MBB,
return expandAtomicCmpXchg(MBB, MBBI, false, 64, NextMBBI);
case RISCV::PseudoMaskedCmpXchg32:
return expandAtomicCmpXchg(MBB, MBBI, true, 32, NextMBBI);
case RISCV::PseudoAMOCAS_W:
return expandAMOCAS(MBB, MBBI, false, 32, NextMBBI);
case RISCV::PseudoAMOCAS_D_RV64:
return expandAMOCAS(MBB, MBBI, false, 64, NextMBBI);
case RISCV::PseudoAMOCAS_D_RV32:
return expandAMOCAS(MBB, MBBI, true, 64, NextMBBI);
case RISCV::PseudoAMOCAS_Q:
return expandAMOCAS(MBB, MBBI, true, 128, NextMBBI);
}

return false;
Expand Down Expand Up @@ -256,6 +267,74 @@ static unsigned getSCForRMW(AtomicOrdering Ordering, int Width,
llvm_unreachable("Unexpected SC width\n");
}

// Pick the AMOCAS.W opcode variant whose aq/rl bits implement \p Ordering.
// With the Ztso extension the machine is already total-store-ordered, so the
// plain (no aq/rl) encoding is sufficient for every ordering.
static unsigned getAMOCASForRMW32(AtomicOrdering Ordering,
                                  const RISCVSubtarget *Subtarget) {
  if (Subtarget->hasStdExtZtso())
    return RISCV::AMOCAS_W;
  if (Ordering == AtomicOrdering::Monotonic)
    return RISCV::AMOCAS_W;
  if (Ordering == AtomicOrdering::Acquire)
    return RISCV::AMOCAS_W_AQ;
  if (Ordering == AtomicOrdering::Release)
    return RISCV::AMOCAS_W_RL;
  // AcquireRelease and SequentiallyConsistent both need aq and rl set.
  if (Ordering == AtomicOrdering::AcquireRelease ||
      Ordering == AtomicOrdering::SequentiallyConsistent)
    return RISCV::AMOCAS_W_AQ_RL;
  llvm_unreachable("Unexpected AtomicOrdering");
}

// Pick the AMOCAS.D opcode variant whose aq/rl bits implement \p Ordering.
// Under Ztso the un-annotated encoding is sufficient for every ordering.
static unsigned getAMOCASForRMW64(AtomicOrdering Ordering,
                                  const RISCVSubtarget *Subtarget) {
  if (Subtarget->hasStdExtZtso())
    return RISCV::AMOCAS_D;
  if (Ordering == AtomicOrdering::Monotonic)
    return RISCV::AMOCAS_D;
  if (Ordering == AtomicOrdering::Acquire)
    return RISCV::AMOCAS_D_AQ;
  if (Ordering == AtomicOrdering::Release)
    return RISCV::AMOCAS_D_RL;
  // AcquireRelease and SequentiallyConsistent both need aq and rl set.
  if (Ordering == AtomicOrdering::AcquireRelease ||
      Ordering == AtomicOrdering::SequentiallyConsistent)
    return RISCV::AMOCAS_D_AQ_RL;
  llvm_unreachable("Unexpected AtomicOrdering");
}

// Pick the AMOCAS.Q opcode variant whose aq/rl bits implement \p Ordering.
// Under Ztso the un-annotated encoding is sufficient for every ordering.
static unsigned getAMOCASForRMW128(AtomicOrdering Ordering,
                                   const RISCVSubtarget *Subtarget) {
  if (Subtarget->hasStdExtZtso())
    return RISCV::AMOCAS_Q;
  if (Ordering == AtomicOrdering::Monotonic)
    return RISCV::AMOCAS_Q;
  if (Ordering == AtomicOrdering::Acquire)
    return RISCV::AMOCAS_Q_AQ;
  if (Ordering == AtomicOrdering::Release)
    return RISCV::AMOCAS_Q_RL;
  // AcquireRelease and SequentiallyConsistent both need aq and rl set.
  if (Ordering == AtomicOrdering::AcquireRelease ||
      Ordering == AtomicOrdering::SequentiallyConsistent)
    return RISCV::AMOCAS_Q_AQ_RL;
  llvm_unreachable("Unexpected AtomicOrdering");
}

// Dispatch to the width-specific AMOCAS opcode selector. \p Width is the
// access size in bits; only 32, 64 and 128 are valid.
static unsigned getAMOCASForRMW(AtomicOrdering Ordering, int Width,
                                const RISCVSubtarget *Subtarget) {
  switch (Width) {
  case 32:
    return getAMOCASForRMW32(Ordering, Subtarget);
  case 64:
    return getAMOCASForRMW64(Ordering, Subtarget);
  case 128:
    return getAMOCASForRMW128(Ordering, Subtarget);
  default:
    llvm_unreachable("Unexpected AMOCAS width\n");
  }
}

static void doAtomicBinOpExpansion(const RISCVInstrInfo *TII, MachineInstr &MI,
DebugLoc DL, MachineBasicBlock *ThisMBB,
MachineBasicBlock *LoopMBB,
Expand Down Expand Up @@ -728,6 +807,38 @@ bool RISCVExpandAtomicPseudo::expandAtomicCmpXchg(
return true;
}

// Map a GPRPair register (X<2n>_PD) to the even GPR (X<2n>) that holds the
// low half of the pair. The pair registers are laid out contiguously, so the
// pair index scaled by two gives the offset from X0.
static Register getGPRPairEvenReg(Register PairedReg) {
  assert(PairedReg >= RISCV::X0_PD && PairedReg <= RISCV::X30_PD &&
         "Invalid GPR pair");
  unsigned PairIdx = PairedReg - RISCV::X0_PD;
  return Register(RISCV::X0 + 2 * PairIdx);
}

// Expand a PseudoAMOCAS_* into the real AMOCAS instruction with the aq/rl
// bits implied by the atomic ordering immediate. For the paired (2*XLen)
// forms the pseudo's operands are GPRPair registers; the hardware
// instruction encodes only the even register of each pair, so we rewrite
// the operands to the even GPR and keep the pair registers as implicit
// operands to preserve liveness for any later passes.
//
// Operand layout of the pseudo: (0) result (tied to cmpval), (1) address,
// (2) compare value, (3) new value, (4) ordering immediate.
bool RISCVExpandAtomicPseudo::expandAMOCAS(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, bool IsPaired,
    int Width, MachineBasicBlock::iterator &NextMBBI) {
  MachineInstr &MI = *MBBI;
  DebugLoc DL = MI.getDebugLoc();

  Register DestPairReg = MI.getOperand(0).getReg();
  Register DestReg = DestPairReg;
  if (IsPaired)
    DestReg = getGPRPairEvenReg(DestPairReg);
  Register AddrReg = MI.getOperand(1).getReg();
  Register NewValPairReg = MI.getOperand(3).getReg();
  Register NewValReg = NewValPairReg;
  if (IsPaired)
    NewValReg = getGPRPairEvenReg(NewValPairReg);
  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(4).getImm());

  MachineInstrBuilder MIB =
      BuildMI(MBB, MBBI, DL, TII->get(getAMOCASForRMW(Ordering, Width, STI)))
          .addReg(DestReg, RegState::Define)
          .addReg(AddrReg)
          .addReg(NewValReg);
  if (IsPaired) {
    // The instruction really defines/reads the full even/odd pair, but the
    // encoded operands above name only the even register. Add the pair
    // registers implicitly so register liveness stays correct for any pass
    // that runs after this expansion. The destination pair is also read,
    // since the result operand is tied to the compare value.
    MIB.addReg(DestPairReg, RegState::ImplicitDefine)
        .addReg(DestPairReg, RegState::Implicit)
        .addReg(NewValPairReg, RegState::Implicit);
  }

  MI.eraseFromParent();
  return true;
}

} // end of anonymous namespace

INITIALIZE_PASS(RISCVExpandAtomicPseudo, "riscv-expand-atomic-pseudo",
Expand Down
65 changes: 64 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
}

if (Subtarget.hasStdExtA()) {
setMaxAtomicSizeInBitsSupported(Subtarget.getXLen());
if (Subtarget.hasStdExtZacas())
setMaxAtomicSizeInBitsSupported(Subtarget.getXLen() * 2);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change affects the behavior of atomicrmw and atomic load/store for 2*XLen types. I'm working my own version of this patch that will address this.

else
setMaxAtomicSizeInBitsSupported(Subtarget.getXLen());
setMinCmpXchgSizeInBits(32);
} else if (Subtarget.hasForcedAtomics()) {
setMaxAtomicSizeInBitsSupported(Subtarget.getXLen());
Expand Down Expand Up @@ -1339,6 +1342,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
XLenVT, LibCall);
}

// Set atomic_cmp_swap operations to expand to AMOCAS.D (RV32) and AMOCAS.Q
// (RV64).
if (Subtarget.hasStdExtZacas())
setOperationAction(ISD::ATOMIC_CMP_SWAP,
Subtarget.is64Bit() ? MVT::i128 : MVT::i64, Custom);

if (Subtarget.hasVendorXTHeadMemIdx()) {
for (unsigned im = (unsigned)ISD::PRE_INC; im != (unsigned)ISD::POST_DEC;
++im) {
Expand Down Expand Up @@ -11075,13 +11084,67 @@ static SDValue customLegalizeToWOpWithSExt(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes);
}

// Create an even/odd pair of X registers holding integer value V.
// The value is split into its low and high XLen halves and assembled into an
// untyped REG_SEQUENCE over the GPRPair register class, with the low half in
// sub_32 (the even register) and the high half in sub_32_hi (the odd one).
static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V, MVT VT,
                                 MVT SubRegVT) {
  SDLoc DL(V.getNode());
  auto [Lo, Hi] = DAG.SplitScalar(V, DL, SubRegVT, SubRegVT);
  SDValue Ops[] = {
      DAG.getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32),
      Lo,
      DAG.getTargetConstant(RISCV::sub_32, DL, MVT::i32),
      Hi,
      DAG.getTargetConstant(RISCV::sub_32_hi, DL, MVT::i32),
  };
  SDNode *Pair =
      DAG.getMachineNode(TargetOpcode::REG_SEQUENCE, DL, MVT::Untyped, Ops);
  return SDValue(Pair, 0);
}

// Type-legalize a 2*XLen-wide ISD::ATOMIC_CMP_SWAP (i64 on RV32, i128 on
// RV64) by selecting the paired-register AMOCAS pseudo directly. The compare
// and store values are packed into GPRPair REG_SEQUENCE nodes, and the
// untyped pair result is split back into XLen halves and rejoined with
// BUILD_PAIR for the legalizer.
static void ReplaceCMP_SWAP_2XLenResults(SDNode *N,
                                         SmallVectorImpl<SDValue> &Results,
                                         SelectionDAG &DAG,
                                         const RISCVSubtarget &Subtarget) {
  MVT VT = N->getSimpleValueType(0);
  assert(N->getValueType(0) == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
         "AtomicCmpSwap on types less than 2*XLen should be legal");
  assert(Subtarget.hasStdExtZacas());
  MVT XLenVT = Subtarget.getXLenVT();

  SDLoc DL(N);
  MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
  // Use the merged (success/failure) ordering so the single AMOCAS is strong
  // enough for both outcomes of the cmpxchg.
  AtomicOrdering Ordering = MemOp->getMergedOrdering();
  SDValue Ops[] = {
      N->getOperand(1),                                     // Ptr
      createGPRPairNode(DAG, N->getOperand(2), VT, XLenVT), // Compare value
      createGPRPairNode(DAG, N->getOperand(3), VT, XLenVT), // Store value
      DAG.getTargetConstant(static_cast<unsigned>(Ordering), DL,
                            MVT::i32), // Ordering
      N->getOperand(0),                // Chain in
  };

  // 2*XLen is AMOCAS.D on RV32 and AMOCAS.Q on RV64.
  unsigned Opcode =
      (VT == MVT::i64 ? RISCV::PseudoAMOCAS_D_RV32 : RISCV::PseudoAMOCAS_Q);
  MachineSDNode *CmpSwap = DAG.getMachineNode(
      Opcode, DL, DAG.getVTList(MVT::Untyped, MVT::Other), Ops);
  // Attach the memory operand so alias analysis/scheduling see the access.
  DAG.setNodeMemRefs(CmpSwap, {MemOp});

  // Extract the two XLen halves of the pair result: sub_32 holds the low
  // half, sub_32_hi the high half.
  unsigned SubReg1 = RISCV::sub_32, SubReg2 = RISCV::sub_32_hi;
  SDValue Lo =
      DAG.getTargetExtractSubreg(SubReg1, DL, XLenVT, SDValue(CmpSwap, 0));
  SDValue Hi =
      DAG.getTargetExtractSubreg(SubReg2, DL, XLenVT, SDValue(CmpSwap, 0));
  Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, VT, Lo, Hi));
  Results.push_back(SDValue(CmpSwap, 1)); // Chain out
}

void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const {
SDLoc DL(N);
switch (N->getOpcode()) {
default:
llvm_unreachable("Don't know how to custom type legalize this operation!");
case ISD::ATOMIC_CMP_SWAP:
ReplaceCMP_SWAP_2XLenResults(N, Results, DAG, Subtarget);
break;
case ISD::STRICT_FP_TO_SINT:
case ISD::STRICT_FP_TO_UINT:
case ISD::FP_TO_SINT:
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoA.td
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,28 @@ multiclass PseudoCmpXchgPat<string Op, Pseudo CmpXchgInst,
(CmpXchgInst GPR:$addr, GPR:$cmp, GPR:$new, 7)>;
}

let Predicates = [HasStdExtZacas] in {
// Pseudo for a Zacas compare-and-swap, expanded post-RA into the real
// AMOCAS instruction with aq/rl bits chosen from the $ordering immediate.
// RC is GPR for XLen-wide accesses and GPRPair for 2*XLen-wide ones.
// The result is tied to $cmpval because AMOCAS reads the compare value
// from, and writes the loaded value to, the same register (pair).
class PseudoAMOCAS<RegisterClass RC = GPR>
    : Pseudo<(outs RC:$res),
             (ins GPR:$addr, RC:$cmpval, RC:$newval, ixlenimm:$ordering), []> {
  let Constraints = "$res = $cmpval";
  let mayLoad = 1;
  let mayStore = 1;
  let hasSideEffects = 0;
}
def PseudoAMOCAS_W : PseudoAMOCAS;
defm : PseudoCmpXchgPat<"atomic_cmp_swap_32", PseudoAMOCAS_W>;

// 64-bit CAS on RV32 needs an even/odd register pair.
let Predicates = [HasStdExtZacas, IsRV32] in
def PseudoAMOCAS_D_RV32 : PseudoAMOCAS<GPRPair>;

let Predicates = [HasStdExtZacas, IsRV64] in {
def PseudoAMOCAS_D_RV64 : PseudoAMOCAS;
defm : PseudoCmpXchgPat<"atomic_cmp_swap_64", PseudoAMOCAS_D_RV64>;
// 128-bit CAS on RV64 needs an even/odd register pair; selected manually in
// ReplaceNodeResults, so no pattern here.
def PseudoAMOCAS_Q : PseudoAMOCAS<GPRPair>;
}
}

def PseudoCmpXchg32 : PseudoCmpXchg;
defm : PseudoCmpXchgPat<"atomic_cmp_swap_32", PseudoCmpXchg32>;

Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/RISCV/RISCVRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,16 @@ def GPRPF64 : RegisterClass<"RISCV", [f64], 64, (add
X0_PD, X2_PD, X4_PD
)>;

// Register class of even/odd GPR pairs used by the paired AMOCAS forms.
// Each pair is 2*XLen wide, so its size/spill-size/alignment is 64 on RV32
// and 128 on RV64 (the previous <32,32,32>/<64,64,64> described a single
// XLen register, not a pair). Allocation order prefers argument/temporary
// registers and puts the reserved X0/X2/X4 pairs last.
let RegInfos = RegInfoByHwMode<[RV32, RV64],
                               [RegInfo<64, 64, 64>, RegInfo<128, 128, 128>]> in
def GPRPair : RegisterClass<"RISCV", [untyped], 64, (add
    X10_PD, X12_PD, X14_PD, X16_PD,
    X6_PD,
    X28_PD, X30_PD,
    X8_PD,
    X18_PD, X20_PD, X22_PD, X24_PD, X26_PD,
    X0_PD, X2_PD, X4_PD
  )>;

// The register class is added for inline assembly for vector mask types.
def VM : VReg<VMaskVTs,
(add (sequence "V%u", 8, 31),
Expand Down
Loading