Skip to content

Commit 1f2cde3

Browse files
committed
[RISCV] RISCV vector calling convention (2/2)
This commit handles vector arguments/return for function definition/call, the new class RVVArgDispatcher is added for doing all vector register assignment including mask types, data types as well as tuple types. It precomputes the register number for each argument as per https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#standard-vector-calling-convention-variant and it's passed to calling convention function to handle all vector arguments. Depends on: #78550
1 parent 546dc22 commit 1f2cde3

File tree

6 files changed

+322
-83
lines changed

6 files changed

+322
-83
lines changed

llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
3434
// Whether this is assigning args for a return.
3535
bool IsRet;
3636

37-
// true if assignArg has been called for a mask argument, false otherwise.
38-
bool AssignedFirstMaskArg = false;
37+
RVVArgDispatcher &RVVDispatcher;
3938

4039
public:
4140
RISCVOutgoingValueAssigner(
42-
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
41+
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
42+
RVVArgDispatcher &RVVDispatcher)
4343
: CallLowering::OutgoingValueAssigner(nullptr),
44-
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
44+
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
45+
RVVDispatcher(RVVDispatcher) {}
4546

4647
bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
4748
CCValAssign::LocInfo LocInfo,
@@ -51,16 +52,9 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
5152
const DataLayout &DL = MF.getDataLayout();
5253
const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
5354

54-
std::optional<unsigned> FirstMaskArgument;
55-
if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
56-
ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
57-
FirstMaskArgument = ValNo;
58-
AssignedFirstMaskArg = true;
59-
}
60-
6155
if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
6256
LocInfo, Flags, State, Info.IsFixed, IsRet, Info.Ty,
63-
*Subtarget.getTargetLowering(), FirstMaskArgument))
57+
*Subtarget.getTargetLowering(), RVVDispatcher))
6458
return true;
6559

6660
StackSize = State.getStackSize();
@@ -181,14 +175,15 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
181175
// Whether this is assigning args from a return.
182176
bool IsRet;
183177

184-
// true if assignArg has been called for a mask argument, false otherwise.
185-
bool AssignedFirstMaskArg = false;
178+
RVVArgDispatcher &RVVDispatcher;
186179

187180
public:
188181
RISCVIncomingValueAssigner(
189-
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
182+
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
183+
RVVArgDispatcher &RVVDispatcher)
190184
: CallLowering::IncomingValueAssigner(nullptr),
191-
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
185+
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
186+
RVVDispatcher(RVVDispatcher) {}
192187

193188
bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
194189
CCValAssign::LocInfo LocInfo,
@@ -201,16 +196,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
201196
if (LocVT.isScalableVector())
202197
MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
203198

204-
std::optional<unsigned> FirstMaskArgument;
205-
if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
206-
ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
207-
FirstMaskArgument = ValNo;
208-
AssignedFirstMaskArg = true;
209-
}
210-
211199
if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
212200
LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
213-
*Subtarget.getTargetLowering(), FirstMaskArgument))
201+
*Subtarget.getTargetLowering(), RVVDispatcher))
214202
return true;
215203

216204
StackSize = State.getStackSize();
@@ -420,9 +408,11 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
420408
SmallVector<ArgInfo, 4> SplitRetInfos;
421409
splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
422410

411+
RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
412+
F.getReturnType()};
423413
RISCVOutgoingValueAssigner Assigner(
424414
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
425-
/*IsRet=*/true);
415+
/*IsRet=*/true, Dispatcher);
426416
RISCVOutgoingValueHandler Handler(MIRBuilder, MF.getRegInfo(), Ret);
427417
return determineAndHandleAssignments(Handler, Assigner, SplitRetInfos,
428418
MIRBuilder, CC, F.isVarArg());
@@ -531,6 +521,7 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
531521
CallingConv::ID CC = F.getCallingConv();
532522

533523
SmallVector<ArgInfo, 32> SplitArgInfos;
524+
SmallVector<Type *, 4> TypeList;
534525
unsigned Index = 0;
535526
for (auto &Arg : F.args()) {
536527
// Construct the ArgInfo object from destination register and argument type.
@@ -542,12 +533,15 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
542533
// correspondingly and appended to SplitArgInfos.
543534
splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
544535

536+
TypeList.push_back(Arg.getType());
537+
545538
++Index;
546539
}
547540

541+
RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
548542
RISCVIncomingValueAssigner Assigner(
549543
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
550-
/*IsRet=*/false);
544+
/*IsRet=*/false, Dispatcher);
551545
RISCVFormalArgHandler Handler(MIRBuilder, MF.getRegInfo());
552546

553547
SmallVector<CCValAssign, 16> ArgLocs;
@@ -585,11 +579,13 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
585579

586580
SmallVector<ArgInfo, 32> SplitArgInfos;
587581
SmallVector<ISD::OutputArg, 8> Outs;
582+
SmallVector<Type *, 4> TypeList;
588583
for (auto &AInfo : Info.OrigArgs) {
589584
// Handle any required unmerging of split value types from a given VReg into
590585
// physical registers. ArgInfo objects are constructed correspondingly and
591586
// appended to SplitArgInfos.
592587
splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
588+
TypeList.push_back(AInfo.Ty);
593589
}
594590

595591
// TODO: Support tail calls.
@@ -607,9 +603,10 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
607603
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
608604
Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
609605

606+
RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
610607
RISCVOutgoingValueAssigner ArgAssigner(
611608
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
612-
/*IsRet=*/false);
609+
/*IsRet=*/false, ArgDispatcher);
613610
RISCVOutgoingValueHandler ArgHandler(MIRBuilder, MF.getRegInfo(), Call);
614611
if (!determineAndHandleAssignments(ArgHandler, ArgAssigner, SplitArgInfos,
615612
MIRBuilder, CC, Info.IsVarArg))
@@ -637,9 +634,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
637634
SmallVector<ArgInfo, 4> SplitRetInfos;
638635
splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
639636

637+
RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
638+
F.getReturnType()};
640639
RISCVIncomingValueAssigner RetAssigner(
641640
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
642-
/*IsRet=*/true);
641+
/*IsRet=*/true, RetDispatcher);
643642
RISCVCallReturnHandler RetHandler(MIRBuilder, MF.getRegInfo(), Call);
644643
if (!determineAndHandleAssignments(RetHandler, RetAssigner, SplitRetInfos,
645644
MIRBuilder, CC, Info.IsVarArg))

0 commit comments

Comments
 (0)