Commit d823879

[WIP][AMDGPU] Improve the handling of inreg arguments
When SGPRs available for `inreg` argument passing run out, the compiler silently falls back to using whole VGPRs to pass those arguments. Ideally, instead of using whole VGPRs, we should pack `inreg` arguments into individual lanes of VGPRs. This PR introduces `InregVGPRSpiller`, which handles this packing. It uses `v_writelane` at the call site to place `inreg` arguments into specific VGPR lanes, and then extracts them in the callee using `v_readlane`. Fixes #130443 and #129071.
1 parent 7f2abe8 commit d823879
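
To make the intended lowering concrete, here is a condensed view of the callee from the new test added in this commit (gfx950 run). The readlane sequence is taken from the test's generated checks; the function body is only a sketch of that test case:

```llvm
; The three vector inreg arguments consume the SGPRs available for argument
; passing, so %arg3 and %arg4 would otherwise each occupy a whole VGPR
; (v0 and v1). With this change they are instead packed into lanes 0 and 1
; of a single VGPR at the call site and read back in the callee.
define i32 @callee(<8 x i32> inreg %arg0, <8 x i32> inreg %arg1,
                   <2 x i32> inreg %arg2, i32 inreg %arg3, i32 inreg %arg4) {
  %sub = sub i32 %arg3, %arg4
  ret i32 %sub
}
; Expected callee prologue (from the new test at the end of this diff):
;   v_readlane_b32 s0, v0, 1
;   v_readlane_b32 s1, v0, 0
;   s_sub_i32 s0, s1, s0
```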

File tree

3 files changed: 232 additions & 9 deletions


llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 168 additions & 8 deletions
@@ -2841,6 +2841,144 @@ void SITargetLowering::insertCopiesSplitCSR(
   }
 }
 
+/// Base class for spilling inreg VGPR arguments.
+///
+/// When an argument marked inreg is pushed to a VGPR, it indicates that the
+/// available SGPRs for argument passing have been exhausted. In such cases, it
+/// is preferable to pack multiple inreg arguments into individual lanes of
+/// VGPRs instead of assigning each directly to separate VGPRs.
+///
+/// Spilling involves two parts: the caller-side (call site) and the
+/// callee-side. Both must follow the same method for selecting registers and
+/// lanes, ensuring that an argument written at the call site matches exactly
+/// with the one read at the callee.
+///
+/// \p InregVGPRSpiller::setReg selects the register used for a given argument.
+/// If \p CurReg is invalid, it uses the register determined by the calling
+/// convention. The first inreg VGPR argument is stored into lane 0.
+///
+/// After reading or writing an argument, \p InregVGPRSpiller::forward advances
+/// the lane counter. When all lanes of a VGPR are used, it resets \p CurReg.
+/// Upon the next read/write operation, the register determined by the calling
+/// convention will be selected again, and lane numbering will restart from 0.
+class InregVGPRSpiller {
+  CCState &State;
+  const unsigned WaveFrontSize;
+
+  Register CurReg;
+  unsigned CurLane = 0;
+
+protected:
+  SelectionDAG &DAG;
+  MachineFunction &MF;
+
+  Register getCurReg() const { return CurReg; }
+  unsigned getCurLane() const { return CurLane % WaveFrontSize; }
+
+  InregVGPRSpiller(SelectionDAG &DAG, MachineFunction &MF, CCState &State)
+      : State(State),
+        WaveFrontSize(MF.getSubtarget<GCNSubtarget>().getWavefrontSize()),
+        DAG(DAG), MF(MF) {}
+
+  void setReg(Register &Reg) {
+    if (CurReg.isValid()) {
+      State.DeallocateReg(Reg);
+      Reg = CurReg;
+    } else {
+      CurReg = Reg;
+    }
+  }
+
+  void forward() {
+    // FIXME: Wrapping may never occur here, since that would imply at least 32
+    // or even 64 inreg arguments, which might exceed ABI limitations.
+    if (++CurLane % WaveFrontSize == 0)
+      CurReg = 0;
+  }
+};
+
+/// The spilling class for the callee-side that lowers unpacking of arguments.
+///
+/// When an argument marked inreg is pushed to a VGPR, it indicates that the
+/// available SGPRs for argument passing have been exhausted. In such cases, it
+/// is preferable to pack multiple inreg arguments into individual lanes of
+/// VGPRs instead of assigning each directly to separate VGPRs.
+///
+/// Spilling involves two parts: the caller-side (call site) and the
+/// callee-side. Both must follow the same method for selecting registers and
+/// lanes, ensuring that an argument written at the call site matches exactly
+/// with the one read at the callee.
+class InregVGPRSpillerCallee {
+  CCState &State;
+  SelectionDAG &DAG;
+  MachineFunction &MF;
+
+  Register SrcReg;
+  SDValue SrcVal;
+  unsigned CurLane = 0;
+
+public:
+  InregVGPRSpillerCallee(SelectionDAG &DAG, MachineFunction &MF, CCState &State)
+      : State(State), DAG(DAG), MF(MF) {}
+
+  SDValue read(SDValue Chain, const SDLoc &SL, Register &Reg, EVT VT) {
+    if (SrcVal) {
+      State.DeallocateReg(Reg);
+    } else {
+      Reg = MF.addLiveIn(Reg, &AMDGPU::VGPR_32RegClass);
+      SrcReg = Reg;
+      SrcVal = DAG.getCopyFromReg(Chain, SL, Reg, VT);
+    }
+    // According to the calling convention, only SGPR4–SGPR29 should be used for
+    // passing 'inreg' function arguments. Therefore, the number of 'inreg' VGPR
+    // arguments must not exceed 26.
+    assert(CurLane < 26 && "more than expected VGPR inreg arguments");
+    SmallVector<SDValue, 4> Operands{
+        DAG.getTargetConstant(Intrinsic::amdgcn_readlane, SL, MVT::i32),
+        DAG.getRegister(SrcReg, VT),
+        DAG.getTargetConstant(CurLane++, SL, MVT::i32)};
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, VT, Operands);
+  }
+};
+
+/// The spilling class for the caller-side that lowers packing of call site
+/// arguments.
+class InregVGPRSpillerCallSite {
+  CCState &State;
+
+  Register DstReg;
+  SDValue Glue;
+  unsigned CurLane = 0;
+
+  SelectionDAG &DAG;
+  MachineFunction &MF;
+
+public:
+  InregVGPRSpillerCallSite(SelectionDAG &DAG, MachineFunction &MF,
+                           CCState &State)
+      : State(State), DAG(DAG), MF(MF) {}
+
+  std::pair<SDValue, SDValue> write(SDValue Chain, const SDLoc &SL,
+                                    Register &Reg, SDValue Val, SDValue InGlue,
+                                    EVT VT) {
+    if (DstReg.isValid()) {
+      Reg = DstReg;
+    } else {
+      DstReg = Reg;
+      Glue = DAG.getCopyToReg(Chain, SL, Reg, Val, InGlue).getValue(1);
+    }
+    // According to the calling convention, only SGPR4–SGPR29 should be used for
+    // passing 'inreg' function arguments. Therefore, the number of 'inreg' VGPR
+    // arguments must not exceed 26.
+    assert(CurLane < 26 && "more than expected VGPR inreg arguments");
+    SmallVector<SDValue, 4> Operands{
+        DAG.getTargetConstant(Intrinsic::amdgcn_writelane, SL, MVT::i32),
+        DAG.getRegister(DstReg, VT), Val,
+        DAG.getTargetConstant(CurLane++, SL, MVT::i32)};
+    return {DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, VT, Operands), Glue};
+  }
+};
+
 SDValue SITargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -2963,6 +3101,7 @@ SDValue SITargetLowering::LowerFormalArguments(
   // FIXME: Alignment of explicit arguments totally broken with non-0 explicit
   // kern arg offset.
   const Align KernelArgBaseAlign = Align(16);
+  InregVGPRSpillerCallee Spiller(DAG, MF, CCInfo);
 
   for (unsigned i = 0, e = Ins.size(), ArgIdx = 0; i != e; ++i) {
     const ISD::InputArg &Arg = Ins[i];
@@ -3130,8 +3269,17 @@ SDValue SITargetLowering::LowerFormalArguments(
       llvm_unreachable("Unexpected register class in LowerFormalArguments!");
     EVT ValVT = VA.getValVT();
 
-    Reg = MF.addLiveIn(Reg, RC);
-    SDValue Val = DAG.getCopyFromReg(Chain, DL, Reg, VT);
+    SDValue Val;
+    // If an argument is marked inreg but gets pushed to a VGPR, it indicates
+    // we've run out of SGPRs for argument passing. In such cases, we'd prefer
+    // to start packing inreg arguments into individual lanes of VGPRs, rather
+    // than placing them directly into VGPRs.
+    if (RC == &AMDGPU::VGPR_32RegClass && Arg.Flags.isInReg()) {
+      Val = Spiller.read(Chain, DL, Reg, VT);
+    } else {
+      Reg = MF.addLiveIn(Reg, RC);
+      Val = DAG.getCopyFromReg(Chain, DL, Reg, VT);
+    }
 
     if (Arg.Flags.isSRet()) {
       // The return object should be reasonably addressable.
@@ -3373,7 +3521,7 @@ SDValue SITargetLowering::LowerCallResult(
 // from the explicit user arguments present in the IR.
 void SITargetLowering::passSpecialInputs(
     CallLoweringInfo &CLI, CCState &CCInfo, const SIMachineFunctionInfo &Info,
-    SmallVectorImpl<std::pair<unsigned, SDValue>> &RegsToPass,
+    SmallVectorImpl<std::pair<Register, SDValue>> &RegsToPass,
     SmallVectorImpl<SDValue> &MemOpChains, SDValue Chain) const {
   // If we don't have a call site, this was a call inserted by
   // legalization. These can never use special inputs.
@@ -3817,7 +3965,7 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
   }
 
   const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
-  SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
+  SmallVector<std::pair<Register, SDValue>, 8> RegsToPass;
   SmallVector<SDValue, 8> MemOpChains;
 
   // Analyze operands of the call, assigning locations to each operand.
@@ -3875,6 +4023,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   MVT PtrVT = MVT::i32;
 
+  InregVGPRSpillerCallSite Spiller(DAG, MF, CCInfo);
+
   // Walk the register/memloc assignments, inserting copies/loads.
   for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
     CCValAssign &VA = ArgLocs[i];
@@ -3988,8 +4138,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
   SDValue InGlue;
 
   unsigned ArgIdx = 0;
-  for (auto [Reg, Val] : RegsToPass) {
-    if (ArgIdx++ >= NumSpecialInputs &&
+  for (auto &[Reg, Val] : RegsToPass) {
+    if (ArgIdx >= NumSpecialInputs &&
         (IsChainCallConv || !Val->isDivergent()) && TRI->isSGPRPhysReg(Reg)) {
       // For chain calls, the inreg arguments are required to be
       // uniform. Speculatively Insert a readfirstlane in case we cannot prove
@@ -4008,8 +4158,18 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
                         ReadfirstlaneArgs);
     }
 
-    Chain = DAG.getCopyToReg(Chain, DL, Reg, Val, InGlue);
-    InGlue = Chain.getValue(1);
+    if (ArgIdx >= NumSpecialInputs &&
+        Outs[ArgIdx - NumSpecialInputs].Flags.isInReg() &&
+        AMDGPU::VGPR_32RegClass.contains(Reg)) {
+      std::tie(Chain, InGlue) =
+          Spiller.write(Chain, DL, Reg, Val, InGlue,
+                        ArgLocs[ArgIdx - NumSpecialInputs].getLocVT());
+    } else {
+      Chain = DAG.getCopyToReg(Chain, DL, Reg, Val, InGlue);
+      InGlue = Chain.getValue(1);
+    }
+
+    ++ArgIdx;
   }
 
   // We don't usually want to end the call-sequence here because we would tidy

llvm/lib/Target/AMDGPU/SIISelLowering.h

Lines changed: 1 addition & 1 deletion
@@ -406,7 +406,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
       CallLoweringInfo &CLI,
       CCState &CCInfo,
       const SIMachineFunctionInfo &Info,
-      SmallVectorImpl<std::pair<unsigned, SDValue>> &RegsToPass,
+      SmallVectorImpl<std::pair<Register, SDValue>> &RegsToPass,
       SmallVectorImpl<SDValue> &MemOpChains,
       SDValue Chain) const;
 
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx950 -o - %s | FileCheck %s
+
+; arg3 is in v0, arg4 is in v1. These should be packed into lanes and extracted with v_readlane.
+define i32 @callee(<8 x i32> inreg %arg0, <8 x i32> inreg %arg1, <2 x i32> inreg %arg2, i32 inreg %arg3, i32 inreg %arg4) {
+; CHECK-LABEL: callee:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_readlane_b32 s0, v0, 1
+; CHECK-NEXT:    v_readlane_b32 s1, v0, 0
+; CHECK-NEXT:    s_sub_i32 s0, s1, s0
+; CHECK-NEXT:    v_mov_b32_e32 v0, s0
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %add = sub i32 %arg3, %arg4
+  ret i32 %add
+}
+
+define amdgpu_kernel void @kernel(<8 x i32> %arg0, <8 x i32> %arg1, <2 x i32> %arg2, i32 %arg3, i32 %arg4, ptr %p) {
+; CHECK-LABEL: kernel:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_mov_b32 s12, s8
+; CHECK-NEXT:    s_add_u32 s8, s4, 0x58
+; CHECK-NEXT:    s_mov_b32 s13, s9
+; CHECK-NEXT:    s_addc_u32 s9, s5, 0
+; CHECK-NEXT:    s_load_dwordx16 s[36:51], s[4:5], 0x0
+; CHECK-NEXT:    s_load_dwordx4 s[28:31], s[4:5], 0x40
+; CHECK-NEXT:    s_load_dwordx2 s[34:35], s[4:5], 0x50
+; CHECK-NEXT:    s_getpc_b64 s[4:5]
+; CHECK-NEXT:    s_add_u32 s4, s4, callee@gotpcrel32@lo+4
+; CHECK-NEXT:    s_addc_u32 s5, s5, callee@gotpcrel32@hi+12
+; CHECK-NEXT:    s_load_dwordx2 s[52:53], s[4:5], 0x0
+; CHECK-NEXT:    s_mov_b32 s14, s10
+; CHECK-NEXT:    s_mov_b64 s[10:11], s[6:7]
+; CHECK-NEXT:    s_mov_b64 s[4:5], s[0:1]
+; CHECK-NEXT:    s_mov_b64 s[6:7], s[2:3]
+; CHECK-NEXT:    v_mov_b32_e32 v31, v0
+; CHECK-NEXT:    s_waitcnt lgkmcnt(0)
+; CHECK-NEXT:    s_mov_b32 s0, s36
+; CHECK-NEXT:    s_mov_b32 s1, s37
+; CHECK-NEXT:    s_mov_b32 s2, s38
+; CHECK-NEXT:    s_mov_b32 s3, s39
+; CHECK-NEXT:    s_mov_b32 s16, s40
+; CHECK-NEXT:    s_mov_b32 s17, s41
+; CHECK-NEXT:    s_mov_b32 s18, s42
+; CHECK-NEXT:    s_mov_b32 s19, s43
+; CHECK-NEXT:    s_mov_b32 s20, s44
+; CHECK-NEXT:    s_mov_b32 s21, s45
+; CHECK-NEXT:    s_mov_b32 s22, s46
+; CHECK-NEXT:    s_mov_b32 s23, s47
+; CHECK-NEXT:    s_mov_b32 s24, s48
+; CHECK-NEXT:    s_mov_b32 s25, s49
+; CHECK-NEXT:    s_mov_b32 s26, s50
+; CHECK-NEXT:    s_mov_b32 s27, s51
+; CHECK-NEXT:    v_mov_b32_e32 v0, s30
+; CHECK-NEXT:    s_mov_b32 s32, 0
+; CHECK-NEXT:    s_swappc_b64 s[30:31], s[52:53]
+; CHECK-NEXT:    v_mov_b64_e32 v[2:3], s[34:35]
+; CHECK-NEXT:    flat_store_dword v[2:3], v0
+; CHECK-NEXT:    s_endpgm
+  %ret = call i32 @callee(<8 x i32> %arg0, <8 x i32> %arg1, <2 x i32> %arg2, i32 %arg3, i32 %arg4)
+  store i32 %ret, ptr %p
+  ret void
+}
