Skip to content

Commit 91dd844

Browse files
authored
Recommit [RISCV] RISCV vector calling convention (2/2) (#79096) (#87736)
Bug fix: Handle RVV return type in calling convention correctly. Return values are handled in a same way as function arguments. One thing to mention is that if a type can be broken down into homogeneous vector types, e.g. {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}}, it is considered as a vector tuple type and need to be handled by tuple type rule.
1 parent 75244a1 commit 91dd844

File tree

7 files changed

+515
-85
lines changed

7 files changed

+515
-85
lines changed

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,8 +1809,16 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType,
18091809
else if (attr.hasRetAttr(Attribute::ZExt))
18101810
Flags.setZExt();
18111811

1812-
for (unsigned i = 0; i < NumParts; ++i)
1813-
Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, /*isfixed=*/true, 0, 0));
1812+
for (unsigned i = 0; i < NumParts; ++i) {
1813+
ISD::ArgFlagsTy OutFlags = Flags;
1814+
if (NumParts > 1 && i == 0)
1815+
OutFlags.setSplit();
1816+
else if (i == NumParts - 1 && i != 0)
1817+
OutFlags.setSplitEnd();
1818+
1819+
Outs.push_back(
1820+
ISD::OutputArg(OutFlags, PartVT, VT, /*isfixed=*/true, 0, 0));
1821+
}
18141822
}
18151823
}
18161824

llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
3434
// Whether this is assigning args for a return.
3535
bool IsRet;
3636

37-
// true if assignArg has been called for a mask argument, false otherwise.
38-
bool AssignedFirstMaskArg = false;
37+
RVVArgDispatcher &RVVDispatcher;
3938

4039
public:
4140
RISCVOutgoingValueAssigner(
42-
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
41+
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
42+
RVVArgDispatcher &RVVDispatcher)
4343
: CallLowering::OutgoingValueAssigner(nullptr),
44-
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
44+
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
45+
RVVDispatcher(RVVDispatcher) {}
4546

4647
bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
4748
CCValAssign::LocInfo LocInfo,
@@ -51,16 +52,9 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
5152
const DataLayout &DL = MF.getDataLayout();
5253
const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
5354

54-
std::optional<unsigned> FirstMaskArgument;
55-
if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
56-
ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
57-
FirstMaskArgument = ValNo;
58-
AssignedFirstMaskArg = true;
59-
}
60-
6155
if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
6256
LocInfo, Flags, State, Info.IsFixed, IsRet, Info.Ty,
63-
*Subtarget.getTargetLowering(), FirstMaskArgument))
57+
*Subtarget.getTargetLowering(), RVVDispatcher))
6458
return true;
6559

6660
StackSize = State.getStackSize();
@@ -181,14 +175,15 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
181175
// Whether this is assigning args from a return.
182176
bool IsRet;
183177

184-
// true if assignArg has been called for a mask argument, false otherwise.
185-
bool AssignedFirstMaskArg = false;
178+
RVVArgDispatcher &RVVDispatcher;
186179

187180
public:
188181
RISCVIncomingValueAssigner(
189-
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
182+
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
183+
RVVArgDispatcher &RVVDispatcher)
190184
: CallLowering::IncomingValueAssigner(nullptr),
191-
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
185+
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
186+
RVVDispatcher(RVVDispatcher) {}
192187

193188
bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
194189
CCValAssign::LocInfo LocInfo,
@@ -201,16 +196,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
201196
if (LocVT.isScalableVector())
202197
MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
203198

204-
std::optional<unsigned> FirstMaskArgument;
205-
if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
206-
ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
207-
FirstMaskArgument = ValNo;
208-
AssignedFirstMaskArg = true;
209-
}
210-
211199
if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
212200
LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
213-
*Subtarget.getTargetLowering(), FirstMaskArgument))
201+
*Subtarget.getTargetLowering(), RVVDispatcher))
214202
return true;
215203

216204
StackSize = State.getStackSize();
@@ -420,9 +408,11 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
420408
SmallVector<ArgInfo, 4> SplitRetInfos;
421409
splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
422410

411+
RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
412+
ArrayRef(F.getReturnType())};
423413
RISCVOutgoingValueAssigner Assigner(
424414
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
425-
/*IsRet=*/true);
415+
/*IsRet=*/true, Dispatcher);
426416
RISCVOutgoingValueHandler Handler(MIRBuilder, MF.getRegInfo(), Ret);
427417
return determineAndHandleAssignments(Handler, Assigner, SplitRetInfos,
428418
MIRBuilder, CC, F.isVarArg());
@@ -531,6 +521,7 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
531521
CallingConv::ID CC = F.getCallingConv();
532522

533523
SmallVector<ArgInfo, 32> SplitArgInfos;
524+
SmallVector<Type *, 4> TypeList;
534525
unsigned Index = 0;
535526
for (auto &Arg : F.args()) {
536527
// Construct the ArgInfo object from destination register and argument type.
@@ -542,12 +533,16 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
542533
// correspondingly and appended to SplitArgInfos.
543534
splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
544535

536+
TypeList.push_back(Arg.getType());
537+
545538
++Index;
546539
}
547540

541+
RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
542+
ArrayRef(TypeList)};
548543
RISCVIncomingValueAssigner Assigner(
549544
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
550-
/*IsRet=*/false);
545+
/*IsRet=*/false, Dispatcher);
551546
RISCVFormalArgHandler Handler(MIRBuilder, MF.getRegInfo());
552547

553548
SmallVector<CCValAssign, 16> ArgLocs;
@@ -585,11 +580,13 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
585580

586581
SmallVector<ArgInfo, 32> SplitArgInfos;
587582
SmallVector<ISD::OutputArg, 8> Outs;
583+
SmallVector<Type *, 4> TypeList;
588584
for (auto &AInfo : Info.OrigArgs) {
589585
// Handle any required unmerging of split value types from a given VReg into
590586
// physical registers. ArgInfo objects are constructed correspondingly and
591587
// appended to SplitArgInfos.
592588
splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
589+
TypeList.push_back(AInfo.Ty);
593590
}
594591

595592
// TODO: Support tail calls.
@@ -607,9 +604,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
607604
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
608605
Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
609606

607+
RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(),
608+
ArrayRef(TypeList)};
610609
RISCVOutgoingValueAssigner ArgAssigner(
611610
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
612-
/*IsRet=*/false);
611+
/*IsRet=*/false, ArgDispatcher);
613612
RISCVOutgoingValueHandler ArgHandler(MIRBuilder, MF.getRegInfo(), Call);
614613
if (!determineAndHandleAssignments(ArgHandler, ArgAssigner, SplitArgInfos,
615614
MIRBuilder, CC, Info.IsVarArg))
@@ -637,9 +636,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
637636
SmallVector<ArgInfo, 4> SplitRetInfos;
638637
splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
639638

639+
RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
640+
ArrayRef(F.getReturnType())};
640641
RISCVIncomingValueAssigner RetAssigner(
641642
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
642-
/*IsRet=*/true);
643+
/*IsRet=*/true, RetDispatcher);
643644
RISCVCallReturnHandler RetHandler(MIRBuilder, MF.getRegInfo(), Call);
644645
if (!determineAndHandleAssignments(RetHandler, RetAssigner, SplitRetInfos,
645646
MIRBuilder, CC, Info.IsVarArg))

0 commit comments

Comments
 (0)