diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp index f64ded4f2cf96..6e7b67ded23c8 100644 --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -1809,8 +1809,16 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType, else if (attr.hasRetAttr(Attribute::ZExt)) Flags.setZExt(); - for (unsigned i = 0; i < NumParts; ++i) - Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, /*isfixed=*/true, 0, 0)); + for (unsigned i = 0; i < NumParts; ++i) { + ISD::ArgFlagsTy OutFlags = Flags; + if (NumParts > 1 && i == 0) + OutFlags.setSplit(); + else if (i == NumParts - 1 && i != 0) + OutFlags.setSplitEnd(); + + Outs.push_back( + ISD::OutputArg(OutFlags, PartVT, VT, /*isfixed=*/true, 0, 0)); + } } } diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp index 45e19cdea300b..c18892ac62f24 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp @@ -34,14 +34,15 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner { // Whether this is assigning args for a return. bool IsRet; - // true if assignArg has been called for a mask argument, false otherwise. - bool AssignedFirstMaskArg = false; + RVVArgDispatcher &RVVDispatcher; public: RISCVOutgoingValueAssigner( - RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet) + RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet, + RVVArgDispatcher &RVVDispatcher) : CallLowering::OutgoingValueAssigner(nullptr), - RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {} + RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet), + RVVDispatcher(RVVDispatcher) {} bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, @@ -51,16 +52,9 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner { const DataLayout &DL = MF.getDataLayout(); const RISCVSubtarget &Subtarget = MF.getSubtarget(); - std::optional FirstMaskArgument; - if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg && - ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) { - FirstMaskArgument = ValNo; - AssignedFirstMaskArg = true; - } - if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT, LocInfo, Flags, State, Info.IsFixed, IsRet, Info.Ty, - *Subtarget.getTargetLowering(), FirstMaskArgument)) + *Subtarget.getTargetLowering(), RVVDispatcher)) return true; StackSize = State.getStackSize(); @@ -181,14 +175,15 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner { // Whether this is assigning args from a return. bool IsRet; - // true if assignArg has been called for a mask argument, false otherwise. 
- bool AssignedFirstMaskArg = false; + RVVArgDispatcher &RVVDispatcher; public: RISCVIncomingValueAssigner( - RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet) + RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet, + RVVArgDispatcher &RVVDispatcher) : CallLowering::IncomingValueAssigner(nullptr), - RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {} + RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet), + RVVDispatcher(RVVDispatcher) {} bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, @@ -201,16 +196,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner { if (LocVT.isScalableVector()) MF.getInfo()->setIsVectorCall(); - std::optional FirstMaskArgument; - if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg && - ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) { - FirstMaskArgument = ValNo; - AssignedFirstMaskArg = true; - } - if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT, LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty, - *Subtarget.getTargetLowering(), FirstMaskArgument)) + *Subtarget.getTargetLowering(), RVVDispatcher)) return true; StackSize = State.getStackSize(); @@ -420,9 +408,11 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder, SmallVector SplitRetInfos; splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC); + RVVArgDispatcher Dispatcher{&MF, getTLI(), + ArrayRef(F.getReturnType())}; RISCVOutgoingValueAssigner Assigner( CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV, - /*IsRet=*/true); + /*IsRet=*/true, Dispatcher); RISCVOutgoingValueHandler Handler(MIRBuilder, MF.getRegInfo(), Ret); return determineAndHandleAssignments(Handler, Assigner, SplitRetInfos, MIRBuilder, CC, F.isVarArg()); @@ -531,6 +521,7 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, CallingConv::ID CC = F.getCallingConv(); SmallVector SplitArgInfos; + SmallVector TypeList; unsigned Index = 0; for (auto &Arg : F.args()) { // Construct the ArgInfo object from destination register and argument type. @@ -542,12 +533,16 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // correspondingly and appended to SplitArgInfos. splitToValueTypes(AInfo, SplitArgInfos, DL, CC); + TypeList.push_back(Arg.getType()); + ++Index; } + RVVArgDispatcher Dispatcher{&MF, getTLI(), + ArrayRef(TypeList)}; RISCVIncomingValueAssigner Assigner( CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV, - /*IsRet=*/false); + /*IsRet=*/false, Dispatcher); RISCVFormalArgHandler Handler(MIRBuilder, MF.getRegInfo()); SmallVector ArgLocs; @@ -585,11 +580,13 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, SmallVector SplitArgInfos; SmallVector Outs; + SmallVector TypeList; for (auto &AInfo : Info.OrigArgs) { // Handle any required unmerging of split value types from a given VReg into // physical registers. ArgInfo objects are constructed correspondingly and // appended to SplitArgInfos. splitToValueTypes(AInfo, SplitArgInfos, DL, CC); + TypeList.push_back(AInfo.Ty); } // TODO: Support tail calls. @@ -607,9 +604,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv)); + RVVArgDispatcher ArgDispatcher{&MF, getTLI(), + ArrayRef(TypeList)}; RISCVOutgoingValueAssigner ArgAssigner( CC == CallingConv::Fast ? 
RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV, - /*IsRet=*/false); + /*IsRet=*/false, ArgDispatcher); RISCVOutgoingValueHandler ArgHandler(MIRBuilder, MF.getRegInfo(), Call); if (!determineAndHandleAssignments(ArgHandler, ArgAssigner, SplitArgInfos, MIRBuilder, CC, Info.IsVarArg)) @@ -637,9 +636,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, SmallVector SplitRetInfos; splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC); + RVVArgDispatcher RetDispatcher{&MF, getTLI(), + ArrayRef(F.getReturnType())}; RISCVIncomingValueAssigner RetAssigner( CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV, - /*IsRet=*/true); + /*IsRet=*/true, RetDispatcher); RISCVCallReturnHandler RetHandler(MIRBuilder, MF.getRegInfo(), Call); if (!determineAndHandleAssignments(RetHandler, RetAssigner, SplitRetInfos, MIRBuilder, CC, Info.IsVarArg)) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 5a572002091ff..99054ffd1c4f0 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/CodeGen/Analysis.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstrBuilder.h" @@ -18078,33 +18079,12 @@ static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1, return false; } -static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo, - std::optional FirstMaskArgument, - CCState &State, const RISCVTargetLowering &TLI) { - const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT); - if (RC == &RISCV::VRRegClass) { - // Assign the first mask argument to V0. - // This is an interim calling convention and it may be changed in the - // future. - if (FirstMaskArgument && ValNo == *FirstMaskArgument) - return State.AllocateReg(RISCV::V0); - return State.AllocateReg(ArgVRs); - } - if (RC == &RISCV::VRM2RegClass) - return State.AllocateReg(ArgVRM2s); - if (RC == &RISCV::VRM4RegClass) - return State.AllocateReg(ArgVRM4s); - if (RC == &RISCV::VRM8RegClass) - return State.AllocateReg(ArgVRM8s); - llvm_unreachable("Unhandled register class for ValueType"); -} - // Implements the RISC-V calling convention. Returns true upon failure. bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed, bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI, - std::optional FirstMaskArgument) { + RVVArgDispatcher &RVVDispatcher) { unsigned XLen = DL.getLargestLegalIntTypeSizeInBits(); assert(XLen == 32 || XLen == 64); MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64; @@ -18273,7 +18253,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo, else if (ValVT == MVT::f64 && !UseGPRForF64) Reg = State.AllocateReg(ArgFPR64s); else if (ValVT.isVector()) { - Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI); + Reg = RVVDispatcher.getNextPhysReg(); if (!Reg) { // For return values, the vector must be passed fully via registers or // via the stack. 
@@ -18359,9 +18339,15 @@ void RISCVTargetLowering::analyzeInputArgs( unsigned NumArgs = Ins.size(); FunctionType *FType = MF.getFunction().getFunctionType(); - std::optional FirstMaskArgument; - if (Subtarget.hasVInstructions()) - FirstMaskArgument = preAssignMask(Ins); + RVVArgDispatcher Dispatcher; + if (IsRet) { + Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(Ins)}; + } else { + SmallVector TypeList; + for (const Argument &Arg : MF.getFunction().args()) + TypeList.push_back(Arg.getType()); + Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(TypeList)}; + } for (unsigned i = 0; i != NumArgs; ++i) { MVT ArgVT = Ins[i].VT; @@ -18376,7 +18362,7 @@ void RISCVTargetLowering::analyzeInputArgs( RISCVABI::ABI ABI = MF.getSubtarget().getTargetABI(); if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this, - FirstMaskArgument)) { + Dispatcher)) { LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type " << ArgVT << '\n'); llvm_unreachable(nullptr); @@ -18390,9 +18376,13 @@ void RISCVTargetLowering::analyzeOutputArgs( CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const { unsigned NumArgs = Outs.size(); - std::optional FirstMaskArgument; - if (Subtarget.hasVInstructions()) - FirstMaskArgument = preAssignMask(Outs); + SmallVector TypeList; + if (IsRet) + TypeList.push_back(MF.getFunction().getReturnType()); + else if (CLI) + for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs()) + TypeList.push_back(Arg.Ty); + RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(TypeList)}; for (unsigned i = 0; i != NumArgs; i++) { MVT ArgVT = Outs[i].VT; @@ -18402,7 +18392,7 @@ void RISCVTargetLowering::analyzeOutputArgs( RISCVABI::ABI ABI = MF.getSubtarget().getTargetABI(); if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this, - FirstMaskArgument)) { + Dispatcher)) { LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type " << ArgVT << "\n"); llvm_unreachable(nullptr); @@ -18583,7 +18573,7 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed, bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI, - std::optional FirstMaskArgument) { + RVVArgDispatcher &RVVDispatcher) { if (LocVT == MVT::i32 || LocVT == MVT::i64) { if (unsigned Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) { State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo)); @@ -18661,13 +18651,14 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, } if (LocVT.isVector()) { - if (unsigned Reg = - allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) { + MCPhysReg AllocatedVReg = RVVDispatcher.getNextPhysReg(); + if (AllocatedVReg) { // Fixed-length vectors are located in the corresponding scalable-vector // container types. if (ValVT.isFixedLengthVector()) LocVT = TLI.getContainerForFixedLengthVector(LocVT); - State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo)); + State.addLoc( + CCValAssign::getReg(ValNo, ValVT, AllocatedVReg, LocVT, LocInfo)); } else { // Try and pass the address via a "fast" GPR. 
       if (unsigned GPRReg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
@@ -19295,17 +19286,15 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(Outs)};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
-                        ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
-                        *this, FirstMaskArgument))
+                        ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true,
+                        nullptr, *this, Dispatcher))
       return false;
   }
   return true;
@@ -21102,6 +21091,181 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
+// Handle a single argument, such as a return value.
+template <typename Arg>
+void RVVArgDispatcher::constructArgInfos(ArrayRef<Arg> ArgList) {
+  // This lambda determines whether an array of argument parts is made up of
+  // homogeneous scalable vector types.
+  auto isHomogeneousScalableVectorType = [](ArrayRef<Arg> ArgList) {
+    // First, extract the first element in the argument type.
+    auto It = ArgList.begin();
+    MVT FirstArgRegType = It->VT;
+
+    // Return false if there are no parts or the first part needs to be split.
+    if (It == ArgList.end() || It->Flags.isSplit())
+      return false;
+
+    ++It;
+
+    // Return false if this argument type contains only 1 element, or it's not
+    // a scalable vector type.
+    if (It == ArgList.end() || !FirstArgRegType.isScalableVector())
+      return false;
+
+    // Second, check whether the following elements in this argument type are
+    // all the same.
+    for (; It != ArgList.end(); ++It)
+      if (It->Flags.isSplit() || It->VT != FirstArgRegType)
+        return false;
+
+    return true;
+  };
+
+  if (isHomogeneousScalableVectorType(ArgList)) {
+    // Handle as tuple type
+    RVVArgInfos.push_back({(unsigned)ArgList.size(), ArgList[0].VT, false});
+  } else {
+    // Handle as normal vector type
+    bool FirstVMaskAssigned = false;
+    for (const auto &OutArg : ArgList) {
+      MVT RegisterVT = OutArg.VT;
+
+      // Skip non-RVV register type
+      if (!RegisterVT.isVector())
+        continue;
+
+      if (RegisterVT.isFixedLengthVector())
+        RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+      if (!FirstVMaskAssigned && RegisterVT.getVectorElementType() == MVT::i1) {
+        RVVArgInfos.push_back({1, RegisterVT, true});
+        FirstVMaskAssigned = true;
+        continue;
+      }
+
+      RVVArgInfos.push_back({1, RegisterVT, false});
+    }
+  }
+}
+
+// Handle multiple arguments.
+template <>
+void RVVArgDispatcher::constructArgInfos<Type *>(ArrayRef<Type *> TypeList) {
+  const DataLayout &DL = MF->getDataLayout();
+  const Function &F = MF->getFunction();
+  LLVMContext &Context = F.getContext();
+
+  bool FirstVMaskAssigned = false;
+  for (Type *Ty : TypeList) {
+    StructType *STy = dyn_cast<StructType>(Ty);
+    if (STy && STy->containsHomogeneousScalableVectorTypes()) {
+      Type *ElemTy = STy->getTypeAtIndex(0U);
+      EVT VT = TLI->getValueType(DL, ElemTy);
+      MVT RegisterVT =
+          TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+      unsigned NumRegs =
+          TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
+
+      RVVArgInfos.push_back(
+          {NumRegs * STy->getNumElements(), RegisterVT, false});
+    } else {
+      SmallVector<EVT, 4> ValueVTs;
+      ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
+
+      for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues;
+           ++Value) {
+        EVT VT = ValueVTs[Value];
+        MVT RegisterVT =
+            TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+        unsigned NumRegs =
+            TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
+
+        // Skip non-RVV register type
+        if (!RegisterVT.isVector())
+          continue;
+
+        if (RegisterVT.isFixedLengthVector())
+          RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+        if (!FirstVMaskAssigned &&
+            RegisterVT.getVectorElementType() == MVT::i1) {
+          RVVArgInfos.push_back({1, RegisterVT, true});
+          FirstVMaskAssigned = true;
+          --NumRegs;
+        }
+
+        RVVArgInfos.insert(RVVArgInfos.end(), NumRegs, {1, RegisterVT, false});
+      }
+    }
+  }
+}
+
+void RVVArgDispatcher::allocatePhysReg(unsigned NF, unsigned LMul,
+                                       unsigned StartReg) {
+  assert((StartReg % LMul) == 0 &&
+         "Start register number should be multiple of lmul");
+  const MCPhysReg *VRArrays;
+  switch (LMul) {
+  default:
+    report_fatal_error("Invalid lmul");
+  case 1:
+    VRArrays = ArgVRs;
+    break;
+  case 2:
+    VRArrays = ArgVRM2s;
+    break;
+  case 4:
+    VRArrays = ArgVRM4s;
+    break;
+  case 8:
+    VRArrays = ArgVRM8s;
+    break;
+  }
+
+  for (unsigned i = 0; i < NF; ++i)
+    if (StartReg)
+      AllocatedPhysRegs.push_back(VRArrays[(StartReg - 8) / LMul + i]);
+    else
+      AllocatedPhysRegs.push_back(MCPhysReg());
+}
+
+/// This function determines whether each RVV argument is passed by register.
+/// If the argument can be assigned to a VR, give it a specific register;
+/// otherwise, assign it 0, which is an invalid MCPhysReg.
+void RVVArgDispatcher::compute() {
+  uint32_t AssignedMap = 0;
+  auto allocate = [&](const RVVArgInfo &ArgInfo) {
+    // Allocate the first vector mask argument to V0.
+    if (ArgInfo.FirstVMask) {
+      AllocatedPhysRegs.push_back(RISCV::V0);
+      return;
+    }
+
+    unsigned RegsNeeded = divideCeil(
+        ArgInfo.VT.getSizeInBits().getKnownMinValue(), RISCV::RVVBitsPerBlock);
+    unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
+    for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
+         StartReg += RegsNeeded) {
+      uint32_t Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
+      if ((AssignedMap & Map) == 0) {
+        allocatePhysReg(ArgInfo.NF, RegsNeeded, StartReg + 8);
+        AssignedMap |= Map;
+        return;
+      }
+    }
+
+    allocatePhysReg(ArgInfo.NF, RegsNeeded, 0);
+  };
+
+  for (unsigned i = 0; i < RVVArgInfos.size(); ++i)
+    allocate(RVVArgInfos[i]);
+}
+
+MCPhysReg RVVArgDispatcher::getNextPhysReg() {
+  assert(CurIdx < AllocatedPhysRegs.size() && "Index out of range");
+  return AllocatedPhysRegs[CurIdx++];
+}
+
 namespace llvm::RISCVVIntrinsicsTable {
 
 #define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ace5b3fd2b95b..a2456f2fab66b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -24,6 +24,7 @@ namespace llvm {
 class InstructionCost;
 class RISCVSubtarget;
 struct RISCVRegisterInfo;
+class RVVArgDispatcher;
 
 namespace RISCVISD {
 // clang-format off
@@ -875,7 +876,7 @@ class RISCVTargetLowering : public TargetLowering {
                               ISD::ArgFlagsTy ArgFlags, CCState &State,
                               bool IsFixed, bool IsRet, Type *OrigTy,
                               const RISCVTargetLowering &TLI,
-                              std::optional<unsigned> FirstMaskArgument);
+                              RVVArgDispatcher &RVVDispatcher);
 
 private:
   void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
@@ -1015,19 +1016,71 @@ class RISCVTargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 };
 
+/// As per the spec, the rules for passing vector arguments are as follows:
+///
+/// 1. For the first vector mask argument, use v0 to pass it.
+/// 2. For vector data arguments or the rest of the vector mask arguments,
+/// starting from the v8 register, if a vector register group between v8-v23
+/// that has not been allocated can be found and the first register number is
+/// a multiple of LMUL, then allocate this vector register group to the
+/// argument and mark these registers as allocated. Otherwise, the argument is
+/// passed by reference and is replaced in the argument list with its address.
+/// 3. For tuple vector data arguments, starting from the v8 register, if
+/// NFIELDS consecutive vector register groups between v8-v23 that have not
+/// been allocated can be found and the first register number is a multiple of
+/// LMUL, then allocate these vector register groups to the argument and mark
+/// these registers as allocated. Otherwise, the argument is passed by
+/// reference and is replaced in the argument list with its address.
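As a worked example of these rules (illustrative only: the argument names and element types below are hypothetical, and RVVBitsPerBlock is assumed to be 64): for a callee taking (nxv8i1 %mask, nxv4i32 %a, nxv8i32 %b, nxv2i32 %c), rule 1 assigns %mask to v0; rule 2 then assigns %a (LMUL=2) to v8-v9, %b (LMUL=4) to v12-v15 (the first unallocated group whose starting register is a multiple of 4), and %c (LMUL=1) to v10, filling the gap left by %b's alignment requirement. The new case2_* and case3_* tests in calling-conv.ll below exercise exactly this kind of gap filling.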
+class RVVArgDispatcher { +public: + static constexpr unsigned NumArgVRs = 16; + + struct RVVArgInfo { + unsigned NF; + MVT VT; + bool FirstVMask = false; + }; + + template + RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI, + ArrayRef ArgList) + : MF(MF), TLI(TLI) { + constructArgInfos(ArgList); + compute(); + } + + RVVArgDispatcher() = default; + + MCPhysReg getNextPhysReg(); + +private: + SmallVector RVVArgInfos; + SmallVector AllocatedPhysRegs; + + const MachineFunction *MF = nullptr; + const RISCVTargetLowering *TLI = nullptr; + + unsigned CurIdx = 0; + + template void constructArgInfos(ArrayRef Ret); + void compute(); + void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1, + unsigned StartReg = 0); +}; + namespace RISCV { bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed, bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI, - std::optional FirstMaskArgument); + RVVArgDispatcher &RVVDispatcher); bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed, bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI, - std::optional FirstMaskArgument); + RVVArgDispatcher &RVVDispatcher); bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll index 78e8700a9feff..647d3158b6167 100644 --- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll +++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll @@ -162,3 +162,206 @@ define void @caller_tuple_argument({, } %x) } declare void @callee_tuple_argument({, }) + +; %0 -> v8 +; %1 -> v9 +define @case1( %0, %1) { +; CHECK-LABEL: case1: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma +; CHECK-NEXT: vadd.vv v8, v8, v9 +; CHECK-NEXT: ret + %a = add %0, %1 + ret %a +} + +; %0 -> v8 +; %1 -> v10-v11 +; %2 -> v9 +define @case2_1( %0, %1, %2) { +; CHECK-LABEL: case2_1: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma +; CHECK-NEXT: vadd.vv v8, v8, v9 +; CHECK-NEXT: ret + %a = add %0, %2 + ret %a +} +define @case2_2( %0, %1, %2) { +; CHECK-LABEL: case2_2: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, ma +; CHECK-NEXT: vadd.vv v8, v10, v10 +; CHECK-NEXT: ret + %a = add %1, %1 + ret %a +} + +; %0 -> v8 +; %1 -> {v10-v11, v12-v13} +; %2 -> v9 +define @case3_1( %0, {, } %1, %2) { +; CHECK-LABEL: case3_1: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma +; CHECK-NEXT: vadd.vv v8, v8, v9 +; CHECK-NEXT: ret + %add = add %0, %2 + ret %add +} +define @case3_2( %0, {, } %1, %2) { +; CHECK-LABEL: case3_2: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, ma +; CHECK-NEXT: vadd.vv v8, v10, v12 +; CHECK-NEXT: ret + %a = extractvalue { , } %1, 0 + %b = extractvalue { , } %1, 1 + %add = add %a, %b + ret %add +} + +; %0 -> v8 +; %1 -> {by-ref, by-ref} +; %2 -> v9 +define @case4_1( %0, {, } %1, %2) { +; CHECK-LABEL: case4_1: +; CHECK: # %bb.0: +; CHECK-NEXT: csrr a1, vlenb +; CHECK-NEXT: slli a1, a1, 3 +; CHECK-NEXT: add a1, a0, a1 +; CHECK-NEXT: vl8re64.v v8, (a1) +; CHECK-NEXT: vl8re64.v v16, (a0) +; CHECK-NEXT: vsetvli a0, zero, e64, m8, ta, ma +; CHECK-NEXT: vadd.vv v8, v16, v8 +; CHECK-NEXT: ret + %a = extractvalue { , } %1, 0 + %b = extractvalue { , } 
%1, 1 + %add = add %a, %b + ret %add +} +define @case4_2( %0, {, } %1, %2) { +; CHECK-LABEL: case4_2: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma +; CHECK-NEXT: vadd.vv v8, v8, v9 +; CHECK-NEXT: ret + %add = add %0, %2 + ret %add +} + +declare @callee1() +declare void @callee2() +declare void @callee3() +define void @caller() { +; RV32-LABEL: caller: +; RV32: # %bb.0: +; RV32-NEXT: addi sp, sp, -16 +; RV32-NEXT: .cfi_def_cfa_offset 16 +; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill +; RV32-NEXT: .cfi_offset ra, -4 +; RV32-NEXT: call callee1 +; RV32-NEXT: vsetvli a0, zero, e64, m1, ta, ma +; RV32-NEXT: vadd.vv v8, v8, v8 +; RV32-NEXT: call callee2 +; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload +; RV32-NEXT: addi sp, sp, 16 +; RV32-NEXT: ret +; +; RV64-LABEL: caller: +; RV64: # %bb.0: +; RV64-NEXT: addi sp, sp, -16 +; RV64-NEXT: .cfi_def_cfa_offset 16 +; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill +; RV64-NEXT: .cfi_offset ra, -8 +; RV64-NEXT: call callee1 +; RV64-NEXT: vsetvli a0, zero, e64, m1, ta, ma +; RV64-NEXT: vadd.vv v8, v8, v8 +; RV64-NEXT: call callee2 +; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload +; RV64-NEXT: addi sp, sp, 16 +; RV64-NEXT: ret + %a = call @callee1() + %add = add %a, %a + call void @callee2( %add) + ret void +} + +declare {, } @callee_tuple() +define void @caller_tuple() { +; RV32-LABEL: caller_tuple: +; RV32: # %bb.0: +; RV32-NEXT: addi sp, sp, -16 +; RV32-NEXT: .cfi_def_cfa_offset 16 +; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill +; RV32-NEXT: .cfi_offset ra, -4 +; RV32-NEXT: call callee_tuple +; RV32-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; RV32-NEXT: vadd.vv v8, v8, v10 +; RV32-NEXT: call callee3 +; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload +; RV32-NEXT: addi sp, sp, 16 +; RV32-NEXT: ret +; +; RV64-LABEL: caller_tuple: +; RV64: # %bb.0: +; RV64-NEXT: addi sp, sp, -16 +; RV64-NEXT: .cfi_def_cfa_offset 16 +; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill +; RV64-NEXT: .cfi_offset ra, -8 +; RV64-NEXT: call callee_tuple +; RV64-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; RV64-NEXT: vadd.vv v8, v8, v10 +; RV64-NEXT: call callee3 +; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload +; RV64-NEXT: addi sp, sp, 16 +; RV64-NEXT: ret + %a = call {, } @callee_tuple() + %b = extractvalue {, } %a, 0 + %c = extractvalue {, } %a, 1 + %add = add %b, %c + call void @callee3( %add) + ret void +} + +declare {, {, }} @callee_nested() +define void @caller_nested() { +; RV32-LABEL: caller_nested: +; RV32: # %bb.0: +; RV32-NEXT: addi sp, sp, -16 +; RV32-NEXT: .cfi_def_cfa_offset 16 +; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill +; RV32-NEXT: .cfi_offset ra, -4 +; RV32-NEXT: call callee_nested +; RV32-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; RV32-NEXT: vadd.vv v8, v8, v10 +; RV32-NEXT: vadd.vv v8, v8, v12 +; RV32-NEXT: call callee3 +; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload +; RV32-NEXT: addi sp, sp, 16 +; RV32-NEXT: ret +; +; RV64-LABEL: caller_nested: +; RV64: # %bb.0: +; RV64-NEXT: addi sp, sp, -16 +; RV64-NEXT: .cfi_def_cfa_offset 16 +; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill +; RV64-NEXT: .cfi_offset ra, -8 +; RV64-NEXT: call callee_nested +; RV64-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; RV64-NEXT: vadd.vv v8, v8, v10 +; RV64-NEXT: vadd.vv v8, v8, v12 +; RV64-NEXT: call callee3 +; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload +; RV64-NEXT: addi sp, sp, 16 +; RV64-NEXT: ret + %a = call {, {, }} @callee_nested() + %b = extractvalue {, {, }} %a, 0 + %c = extractvalue {, {, }} %a, 1 + %c0 = extractvalue {, } %c, 0 + %c1 = 
extractvalue {, } %c, 1 + %add0 = add %b, %c0 + %add1 = add %add0, %c1 + call void @callee3( %add1) + ret void +} diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll index a320aecc6fce4..6a712080fda74 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll @@ -18,10 +18,10 @@ define {, } @vector_deinterleave_load_nxv16i ; CHECK-NEXT: vmerge.vim v14, v10, 1, v0 ; CHECK-NEXT: vmv1r.v v0, v8 ; CHECK-NEXT: vmerge.vim v12, v10, 1, v0 -; CHECK-NEXT: vnsrl.wi v8, v12, 0 -; CHECK-NEXT: vmsne.vi v0, v8, 0 -; CHECK-NEXT: vnsrl.wi v10, v12, 8 +; CHECK-NEXT: vnsrl.wi v10, v12, 0 ; CHECK-NEXT: vmsne.vi v8, v10, 0 +; CHECK-NEXT: vnsrl.wi v10, v12, 8 +; CHECK-NEXT: vmsne.vi v9, v10, 0 ; CHECK-NEXT: ret %vec = load , ptr %p %retval = call {, } @llvm.experimental.vector.deinterleave2.nxv32i1( %vec) diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll index ef4baf34d23f0..d98597fabcd95 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll @@ -8,18 +8,18 @@ define {, } @vector_deinterleave_nxv16i1_nxv ; CHECK-LABEL: vector_deinterleave_nxv16i1_nxv32i1: ; CHECK: # %bb.0: ; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma -; CHECK-NEXT: vmv.v.i v10, 0 -; CHECK-NEXT: vmerge.vim v8, v10, 1, v0 +; CHECK-NEXT: vmv.v.i v8, 0 +; CHECK-NEXT: vmerge.vim v12, v8, 1, v0 ; CHECK-NEXT: csrr a0, vlenb ; CHECK-NEXT: srli a0, a0, 2 ; CHECK-NEXT: vsetvli a1, zero, e8, mf2, ta, ma ; CHECK-NEXT: vslidedown.vx v0, v0, a0 ; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma -; CHECK-NEXT: vmerge.vim v10, v10, 1, v0 -; CHECK-NEXT: vnsrl.wi v12, v8, 0 -; CHECK-NEXT: vmsne.vi v0, v12, 0 -; CHECK-NEXT: vnsrl.wi v12, v8, 8 -; CHECK-NEXT: vmsne.vi v8, v12, 0 +; CHECK-NEXT: vmerge.vim v14, v8, 1, v0 +; CHECK-NEXT: vnsrl.wi v10, v12, 0 +; CHECK-NEXT: vmsne.vi v8, v10, 0 +; CHECK-NEXT: vnsrl.wi v10, v12, 8 +; CHECK-NEXT: vmsne.vi v9, v10, 0 ; CHECK-NEXT: ret %retval = call {, } @llvm.experimental.vector.deinterleave2.nxv32i1( %vec) ret {, } %retval @@ -102,12 +102,13 @@ define {, } @vector_deinterleave_nxv64i1_nxv ; CHECK-NEXT: vsetvli a0, zero, e8, m4, ta, ma ; CHECK-NEXT: vnsrl.wi v28, v8, 0 ; CHECK-NEXT: vsetvli a0, zero, e8, m8, ta, ma -; CHECK-NEXT: vmsne.vi v0, v24, 0 +; CHECK-NEXT: vmsne.vi v7, v24, 0 ; CHECK-NEXT: vsetvli a0, zero, e8, m4, ta, ma ; CHECK-NEXT: vnsrl.wi v24, v16, 8 ; CHECK-NEXT: vnsrl.wi v28, v8, 8 ; CHECK-NEXT: vsetvli a0, zero, e8, m8, ta, ma -; CHECK-NEXT: vmsne.vi v8, v24, 0 +; CHECK-NEXT: vmsne.vi v9, v24, 0 +; CHECK-NEXT: vmv1r.v v8, v7 ; CHECK-NEXT: ret %retval = call {, } @llvm.experimental.vector.deinterleave2.nxv128i1( %vec) ret {, } %retval
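A note on the allocation scheme introduced above: RVVArgDispatcher::compute() performs a first-fit scan over the sixteen vector argument registers (v8-v23), tracking occupancy in a 32-bit mask and stepping by the group size so that each group starts on an LMUL-aligned register; anything that does not fit gets an invalid register and is passed indirectly. The standalone C++ sketch below models only that scan so it can be read and run in isolation; the names, the NF/LMUL parameters, and the example LMUL sequence are assumptions made for illustration, not code from the patch.

#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <optional>

namespace {

constexpr unsigned NumArgVRs = 16; // v8..v23 are the RVV argument registers.

// First-fit scan: find NF consecutive groups of RegsPerGroup registers each,
// with the first register index aligned to RegsPerGroup. Returns the starting
// index relative to v8, or std::nullopt if nothing fits (the argument would
// then be passed indirectly).
std::optional<unsigned> allocate(uint32_t &AssignedMap, unsigned NF,
                                 unsigned RegsPerGroup) {
  unsigned Total = NF * RegsPerGroup;
  for (unsigned Start = 0; Start + Total <= NumArgVRs; Start += RegsPerGroup) {
    uint32_t Mask = ((1u << Total) - 1) << Start;
    if ((AssignedMap & Mask) == 0) {
      AssignedMap |= Mask; // Mark the whole group as allocated.
      return Start;
    }
  }
  return std::nullopt;
}

} // namespace

int main() {
  uint32_t AssignedMap = 0;
  // Hypothetical data arguments with LMUL 2, 4 and 1 (a first mask argument
  // would already have been sent to v0 and never enters this map).
  for (unsigned LMul : {2u, 4u, 1u}) {
    if (std::optional<unsigned> Start = allocate(AssignedMap, /*NF=*/1, LMul))
      std::cout << "LMUL" << LMul << " -> v" << (8 + *Start) << "..v"
                << (8 + *Start + LMul - 1) << '\n';
    else
      std::cout << "LMUL" << LMul << " -> passed by reference\n";
  }
  // Prints: LMUL2 -> v8..v9, LMUL4 -> v12..v15, LMUL1 -> v10..v10.
  return 0;
}

Because every scan restarts at v8, a later LMUL=1 argument can land in a register that was skipped for alignment (v10 in the sketch's output, v9 in case2_1 above), which is the behaviour the new case2_* and case3_* tests in calling-conv.ll pin down.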