Skip to content

Commit 5fefddd

Browse files
committed
[LoongArch] Enable tail calls for sret functions and relax argument matching
Allow tail-calling functions that return via sret when the caller has an incoming sret pointer that can be forwarded. Remove the overly strict requirement that tail-call argument values must exactly match the caller's incoming arguments. The real constraint is only that the callee uses no more argument stack space than the caller. This fixes musttail codegen and enables significantly more tail-call optimizations.
1 parent 90e1391 commit 5fefddd

File tree

5 files changed

+479
-19
lines changed

5 files changed

+479
-19
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8069,6 +8069,7 @@ SDValue LoongArchTargetLowering::LowerFormalArguments(
80698069
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
80708070

80718071
MachineFunction &MF = DAG.getMachineFunction();
8072+
auto *LoongArchFI = MF.getInfo<LoongArchMachineFunctionInfo>();
80728073

80738074
switch (CallConv) {
80748075
default:
@@ -8140,7 +8141,6 @@ SDValue LoongArchTargetLowering::LowerFormalArguments(
81408141
const TargetRegisterClass *RC = &LoongArch::GPRRegClass;
81418142
MachineFrameInfo &MFI = MF.getFrameInfo();
81428143
MachineRegisterInfo &RegInfo = MF.getRegInfo();
8143-
auto *LoongArchFI = MF.getInfo<LoongArchMachineFunctionInfo>();
81448144

81458145
// Offset of the first variable argument from stack pointer, and size of
81468146
// the vararg save area. For now, the varargs save area is either zero or
@@ -8190,6 +8190,8 @@ SDValue LoongArchTargetLowering::LowerFormalArguments(
81908190
LoongArchFI->setVarArgsSaveSize(VarArgsSaveSize);
81918191
}
81928192

8193+
LoongArchFI->setArgumentStackSize(CCInfo.getStackSize());
8194+
81938195
// All stores are grouped in one node to allow the matching between
81948196
// the size of Ins and InVals. This only happens for vararg functions.
81958197
if (!OutChains.empty()) {
@@ -8246,9 +8248,11 @@ bool LoongArchTargetLowering::isEligibleForTailCallOptimization(
82468248
auto &Outs = CLI.Outs;
82478249
auto &Caller = MF.getFunction();
82488250
auto CallerCC = Caller.getCallingConv();
8251+
auto *LoongArchFI = MF.getInfo<LoongArchMachineFunctionInfo>();
82498252

8250-
// Do not tail call opt if the stack is used to pass parameters.
8251-
if (CCInfo.getStackSize() != 0)
8253+
// If the stack arguments for this call do not fit into the caller's own
8254+
// incoming argument area then it cannot be made a tail call.
8255+
if (CCInfo.getStackSize() > LoongArchFI->getArgumentStackSize())
82528256
return false;
82538257

82548258
// Do not tail call opt if any parameters need to be passed indirectly.
@@ -8260,7 +8264,7 @@ bool LoongArchTargetLowering::isEligibleForTailCallOptimization(
82608264
// semantics.
82618265
auto IsCallerStructRet = Caller.hasStructRetAttr();
82628266
auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
8263-
if (IsCallerStructRet || IsCalleeStructRet)
8267+
if (IsCallerStructRet != IsCalleeStructRet)
82648268
return false;
82658269

82668270
// Do not tail call opt if either the callee or caller has a byval argument.
@@ -8276,9 +8280,47 @@ bool LoongArchTargetLowering::isEligibleForTailCallOptimization(
82768280
if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
82778281
return false;
82788282
}
8283+
8284+
// Arguments passed in callee-saved registers must hold the same values the
8285+
// caller received in them; otherwise the call cannot be made a tail call.
8286+
const MachineRegisterInfo &MRI = MF.getRegInfo();
8287+
const SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
8288+
if (!parametersInCSRMatch(MRI, CallerPreserved, ArgLocs, OutVals))
8289+
return false;
8290+
82798291
return true;
82808292
}
82818293

8294+
SDValue LoongArchTargetLowering::addTokenForArgument(SDValue Chain,
8295+
SelectionDAG &DAG,
8296+
MachineFrameInfo &MFI,
8297+
int ClobberedFI) const {
8298+
SmallVector<SDValue, 8> ArgChains;
8299+
int64_t FirstByte = MFI.getObjectOffset(ClobberedFI);
8300+
int64_t LastByte = FirstByte + MFI.getObjectSize(ClobberedFI) - 1;
8301+
8302+
// Include the original chain at the beginning of the list. When this is
8303+
// used by target LowerCall hooks, this helps the legalizer find the
8304+
// CALLSEQ_BEGIN node.
8305+
ArgChains.push_back(Chain);
8306+
8307+
// Add a chain value for each stack argument overlapping ClobberedFI.
8308+
for (SDNode *U : DAG.getEntryNode().getNode()->users())
8309+
if (LoadSDNode *L = dyn_cast<LoadSDNode>(U))
8310+
if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr()))
8311+
if (FI->getIndex() < 0) {
8312+
int64_t InFirstByte = MFI.getObjectOffset(FI->getIndex());
8313+
int64_t InLastByte = InFirstByte;
8314+
InLastByte += MFI.getObjectSize(FI->getIndex()) - 1;
8315+
8316+
if ((InFirstByte <= FirstByte && FirstByte <= InLastByte) ||
8317+
(FirstByte <= InFirstByte && InFirstByte <= LastByte))
8318+
ArgChains.push_back(SDValue(L, 1));
8319+
}
8320+
8321+
// Build a tokenfactor for all the chains.
8322+
return DAG.getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
8323+
}
82828324
static Align getPrefTypeAlign(EVT VT, SelectionDAG &DAG) {
82838325
return DAG.getDataLayout().getPrefTypeAlign(
82848326
VT.getTypeForEVT(*DAG.getContext()));
@@ -8454,19 +8496,32 @@ LoongArchTargetLowering::LowerCall(CallLoweringInfo &CLI,
84548496
RegsToPass.push_back(std::make_pair(VA.getLocReg(), ArgValue));
84558497
} else {
84568498
assert(VA.isMemLoc() && "Argument not register or memory");
8457-
assert(!IsTailCall && "Tail call not allowed if stack is used "
8458-
"for passing parameters");
8499+
SDValue DstAddr;
8500+
MachinePointerInfo DstInfo;
8501+
int32_t Offset = VA.getLocMemOffset();
84598502

84608503
// Work out the address of the stack slot.
84618504
if (!StackPtr.getNode())
84628505
StackPtr = DAG.getCopyFromReg(Chain, DL, LoongArch::R3, PtrVT);
8463-
SDValue Address =
8464-
DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr,
8465-
DAG.getIntPtrConstant(VA.getLocMemOffset(), DL));
8506+
8507+
if (IsTailCall) {
8508+
unsigned OpSize = (VA.getValVT().getSizeInBits() + 7) / 8;
8509+
int FI = MF.getFrameInfo().CreateFixedObject(OpSize, Offset, true);
8510+
DstAddr = DAG.getFrameIndex(FI, PtrVT);
8511+
DstInfo = MachinePointerInfo::getFixedStack(MF, FI);
8512+
// Make sure any stack arguments overlapping with where we're storing
8513+
// are loaded before this eventual operation. Otherwise they'll be
8514+
// clobbered.
8515+
Chain = addTokenForArgument(Chain, DAG, MF.getFrameInfo(), FI);
8516+
} else {
8517+
SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
8518+
DstAddr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
8519+
DstInfo = MachinePointerInfo::getStack(MF, Offset);
8520+
}
84668521

84678522
// Emit the store.
84688523
MemOpChains.push_back(
8469-
DAG.getStore(Chain, DL, ArgValue, Address, MachinePointerInfo()));
8524+
DAG.getStore(Chain, DL, ArgValue, DstAddr, DstInfo));
84708525
}
84718526
}
84728527

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,12 @@ class LoongArchTargetLowering : public TargetLowering {
438438
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
439439
const SmallVectorImpl<CCValAssign> &ArgLocs) const;
440440

441+
/// Finds the incoming stack arguments which overlap the given fixed stack
442+
/// object and incorporates their load into the current chain. This prevents
443+
/// an upcoming store from clobbering the stack argument before it's used.
444+
SDValue addTokenForArgument(SDValue Chain, SelectionDAG &DAG,
445+
MachineFrameInfo &MFI, int ClobberedFI) const;
446+
441447
bool softPromoteHalfType() const override { return true; }
442448

443449
bool

llvm/lib/Target/LoongArch/LoongArchMachineFunctionInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class LoongArchMachineFunctionInfo : public MachineFunctionInfo {
3232
/// Size of stack frame to save callee saved registers
3333
unsigned CalleeSavedStackSize = 0;
3434

35+
/// Number of bytes of stack consumed by this function's incoming arguments
36+
/// that are passed on the stack.
37+
unsigned ArgumentStackSize = 0;
38+
3539
/// FrameIndex of the spill slot when there is no scavenged register in
3640
/// insertIndirectBranch.
3741
int BranchRelaxationSpillFrameIndex = -1;
@@ -63,6 +67,9 @@ class LoongArchMachineFunctionInfo : public MachineFunctionInfo {
6367
unsigned getCalleeSavedStackSize() const { return CalleeSavedStackSize; }
6468
void setCalleeSavedStackSize(unsigned Size) { CalleeSavedStackSize = Size; }
6569

70+
unsigned getArgumentStackSize() const { return ArgumentStackSize; }
71+
void setArgumentStackSize(unsigned size) { ArgumentStackSize = size; }
72+
6673
int getBranchRelaxationSpillFrameIndex() {
6774
return BranchRelaxationSpillFrameIndex;
6875
}

0 commit comments

Comments
 (0)