Skip to content

Commit 3e1317f

Browse files
committed
[RISCV] Support extraction of misaligned subvectors
This patch extends the support for RVV EXTRACT_SUBVECTOR to cover subvector extractions which don't align to a vector register boundary. It accomplishes this by extracting the nearest register-sized subvector (a subregister operation), then sliding the vector down with VSLIDEDOWN and extracting the subvector from the first position (a COPY operation).

Since this procedure involves the use of VSCALE and multiplication, the handling of such operations is done during lowering to simplify the implementation and make use of DAG combining. This necessitated moving some helper functions from RISCVISelDAGToDAG to RISCVTargetLowering.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D96959
1 parent 9aa20ca commit 3e1317f

File tree

4 files changed

+375
-135
lines changed

4 files changed

+375
-135
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 36 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "RISCVISelDAGToDAG.h"
1414
#include "MCTargetDesc/RISCVMCTargetDesc.h"
1515
#include "MCTargetDesc/RISCVMatInt.h"
16+
#include "RISCVISelLowering.h"
1617
#include "llvm/CodeGen/MachineFrameInfo.h"
1718
#include "llvm/IR/IntrinsicsRISCV.h"
1819
#include "llvm/Support/Alignment.h"
@@ -62,64 +63,6 @@ static SDNode *selectImm(SelectionDAG *CurDAG, const SDLoc &DL, int64_t Imm,
6263
return Result;
6364
}
6465

65-
static RISCVVLMUL getLMUL(MVT VT) {
66-
switch (VT.getSizeInBits().getKnownMinValue() / 8) {
67-
default:
68-
llvm_unreachable("Invalid LMUL.");
69-
case 1:
70-
return RISCVVLMUL::LMUL_F8;
71-
case 2:
72-
return RISCVVLMUL::LMUL_F4;
73-
case 4:
74-
return RISCVVLMUL::LMUL_F2;
75-
case 8:
76-
return RISCVVLMUL::LMUL_1;
77-
case 16:
78-
return RISCVVLMUL::LMUL_2;
79-
case 32:
80-
return RISCVVLMUL::LMUL_4;
81-
case 64:
82-
return RISCVVLMUL::LMUL_8;
83-
}
84-
}
85-
86-
static unsigned getRegClassIDForLMUL(RISCVVLMUL LMul) {
87-
switch (LMul) {
88-
default:
89-
llvm_unreachable("Invalid LMUL.");
90-
case RISCVVLMUL::LMUL_F8:
91-
case RISCVVLMUL::LMUL_F4:
92-
case RISCVVLMUL::LMUL_F2:
93-
case RISCVVLMUL::LMUL_1:
94-
return RISCV::VRRegClassID;
95-
case RISCVVLMUL::LMUL_2:
96-
return RISCV::VRM2RegClassID;
97-
case RISCVVLMUL::LMUL_4:
98-
return RISCV::VRM4RegClassID;
99-
case RISCVVLMUL::LMUL_8:
100-
return RISCV::VRM8RegClassID;
101-
}
102-
}
103-
104-
static unsigned getSubregIndexByMVT(MVT VT, unsigned Index) {
105-
RISCVVLMUL LMUL = getLMUL(VT);
106-
if (LMUL == RISCVVLMUL::LMUL_F8 || LMUL == RISCVVLMUL::LMUL_F4 ||
107-
LMUL == RISCVVLMUL::LMUL_F2 || LMUL == RISCVVLMUL::LMUL_1) {
108-
static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
109-
"Unexpected subreg numbering");
110-
return RISCV::sub_vrm1_0 + Index;
111-
} else if (LMUL == RISCVVLMUL::LMUL_2) {
112-
static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
113-
"Unexpected subreg numbering");
114-
return RISCV::sub_vrm2_0 + Index;
115-
} else if (LMUL == RISCVVLMUL::LMUL_4) {
116-
static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
117-
"Unexpected subreg numbering");
118-
return RISCV::sub_vrm4_0 + Index;
119-
}
120-
llvm_unreachable("Invalid vector type.");
121-
}
122-
12366
static SDValue createTupleImpl(SelectionDAG &CurDAG, ArrayRef<SDValue> Regs,
12467
unsigned RegClassID, unsigned SubReg0) {
12568
assert(Regs.size() >= 2 && Regs.size() <= 8);
@@ -187,7 +130,7 @@ void RISCVDAGToDAGISel::selectVLSEG(SDNode *Node, bool IsMasked,
187130
MVT VT = Node->getSimpleValueType(0);
188131
unsigned ScalarSize = VT.getScalarSizeInBits();
189132
MVT XLenVT = Subtarget->getXLenVT();
190-
RISCVVLMUL LMUL = getLMUL(VT);
133+
RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
191134
SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
192135
unsigned CurOp = 2;
193136
SmallVector<SDValue, 7> Operands;
@@ -218,10 +161,11 @@ void RISCVDAGToDAGISel::selectVLSEG(SDNode *Node, bool IsMasked,
218161
CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
219162

220163
SDValue SuperReg = SDValue(Load, 0);
221-
for (unsigned I = 0; I < NF; ++I)
164+
for (unsigned I = 0; I < NF; ++I) {
165+
unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
222166
ReplaceUses(SDValue(Node, I),
223-
CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
224-
VT, SuperReg));
167+
CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
168+
}
225169

226170
ReplaceUses(SDValue(Node, NF), SDValue(Load, 1));
227171
CurDAG->RemoveDeadNode(Node);
@@ -233,7 +177,7 @@ void RISCVDAGToDAGISel::selectVLSEGFF(SDNode *Node, bool IsMasked) {
233177
MVT VT = Node->getSimpleValueType(0);
234178
MVT XLenVT = Subtarget->getXLenVT();
235179
unsigned ScalarSize = VT.getScalarSizeInBits();
236-
RISCVVLMUL LMUL = getLMUL(VT);
180+
RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
237181
SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
238182

239183
unsigned CurOp = 2;
@@ -265,10 +209,11 @@ void RISCVDAGToDAGISel::selectVLSEGFF(SDNode *Node, bool IsMasked) {
265209
CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
266210

267211
SDValue SuperReg = SDValue(Load, 0);
268-
for (unsigned I = 0; I < NF; ++I)
212+
for (unsigned I = 0; I < NF; ++I) {
213+
unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
269214
ReplaceUses(SDValue(Node, I),
270-
CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
271-
VT, SuperReg));
215+
CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
216+
}
272217

273218
ReplaceUses(SDValue(Node, NF), SDValue(ReadVL, 0)); // VL
274219
ReplaceUses(SDValue(Node, NF + 1), SDValue(Load, 1)); // Chain
@@ -282,7 +227,7 @@ void RISCVDAGToDAGISel::selectVLXSEG(SDNode *Node, bool IsMasked,
282227
MVT VT = Node->getSimpleValueType(0);
283228
unsigned ScalarSize = VT.getScalarSizeInBits();
284229
MVT XLenVT = Subtarget->getXLenVT();
285-
RISCVVLMUL LMUL = getLMUL(VT);
230+
RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
286231
SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
287232
unsigned CurOp = 2;
288233
SmallVector<SDValue, 7> Operands;
@@ -307,7 +252,7 @@ void RISCVDAGToDAGISel::selectVLXSEG(SDNode *Node, bool IsMasked,
307252
assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
308253
"Element count mismatch");
309254

310-
RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
255+
RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
311256
unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
312257
const RISCV::VLXSEGPseudo *P = RISCV::getVLXSEGPseudo(
313258
NF, IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -319,10 +264,11 @@ void RISCVDAGToDAGISel::selectVLXSEG(SDNode *Node, bool IsMasked,
319264
CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
320265

321266
SDValue SuperReg = SDValue(Load, 0);
322-
for (unsigned I = 0; I < NF; ++I)
267+
for (unsigned I = 0; I < NF; ++I) {
268+
unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
323269
ReplaceUses(SDValue(Node, I),
324-
CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
325-
VT, SuperReg));
270+
CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
271+
}
326272

327273
ReplaceUses(SDValue(Node, NF), SDValue(Load, 1));
328274
CurDAG->RemoveDeadNode(Node);
@@ -339,7 +285,7 @@ void RISCVDAGToDAGISel::selectVSSEG(SDNode *Node, bool IsMasked,
339285
MVT VT = Node->getOperand(2)->getSimpleValueType(0);
340286
unsigned ScalarSize = VT.getScalarSizeInBits();
341287
MVT XLenVT = Subtarget->getXLenVT();
342-
RISCVVLMUL LMUL = getLMUL(VT);
288+
RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
343289
SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
344290
SmallVector<SDValue, 8> Regs(Node->op_begin() + 2, Node->op_begin() + 2 + NF);
345291
SDValue StoreVal = createTuple(*CurDAG, Regs, NF, LMUL);
@@ -376,7 +322,7 @@ void RISCVDAGToDAGISel::selectVSXSEG(SDNode *Node, bool IsMasked,
376322
MVT VT = Node->getOperand(2)->getSimpleValueType(0);
377323
unsigned ScalarSize = VT.getScalarSizeInBits();
378324
MVT XLenVT = Subtarget->getXLenVT();
379-
RISCVVLMUL LMUL = getLMUL(VT);
325+
RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
380326
SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
381327
SmallVector<SDValue, 7> Operands;
382328
SmallVector<SDValue, 8> Regs(Node->op_begin() + 2, Node->op_begin() + 2 + NF);
@@ -397,7 +343,7 @@ void RISCVDAGToDAGISel::selectVSXSEG(SDNode *Node, bool IsMasked,
397343
assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
398344
"Element count mismatch");
399345

400-
RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
346+
RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
401347
unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
402348
const RISCV::VSXSEGPseudo *P = RISCV::getVSXSEGPseudo(
403349
NF, IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -411,47 +357,6 @@ void RISCVDAGToDAGISel::selectVSXSEG(SDNode *Node, bool IsMasked,
411357
ReplaceNode(Node, Store);
412358
}
413359

414-
static unsigned getRegClassIDForVecVT(MVT VT) {
415-
if (VT.getVectorElementType() == MVT::i1)
416-
return RISCV::VRRegClassID;
417-
return getRegClassIDForLMUL(getLMUL(VT));
418-
}
419-
420-
// Attempt to decompose a subvector insert/extract between VecVT and
421-
// SubVecVT via subregister indices. Returns the subregister index that
422-
// can perform the subvector insert/extract with the given element index, as
423-
// well as the index corresponding to any leftover subvectors that must be
424-
// further inserted/extracted within the register class for SubVecVT.
425-
static std::pair<unsigned, unsigned>
426-
decomposeSubvectorInsertExtractToSubRegs(MVT VecVT, MVT SubVecVT,
427-
unsigned InsertExtractIdx,
428-
const RISCVRegisterInfo *TRI) {
429-
static_assert((RISCV::VRM8RegClassID > RISCV::VRM4RegClassID &&
430-
RISCV::VRM4RegClassID > RISCV::VRM2RegClassID &&
431-
RISCV::VRM2RegClassID > RISCV::VRRegClassID),
432-
"Register classes not ordered");
433-
unsigned VecRegClassID = getRegClassIDForVecVT(VecVT);
434-
unsigned SubRegClassID = getRegClassIDForVecVT(SubVecVT);
435-
// Try to compose a subregister index that takes us from the incoming
436-
// LMUL>1 register class down to the outgoing one. At each step we half
437-
// the LMUL:
438-
// nxv16i32@12 -> nxv2i32: sub_vrm4_1_then_sub_vrm2_1_then_sub_vrm1_0
439-
// Note that this is not guaranteed to find a subregister index, such as
440-
// when we are extracting from one VR type to another.
441-
unsigned SubRegIdx = RISCV::NoSubRegister;
442-
for (const unsigned RCID :
443-
{RISCV::VRM4RegClassID, RISCV::VRM2RegClassID, RISCV::VRRegClassID})
444-
if (VecRegClassID > RCID && SubRegClassID <= RCID) {
445-
VecVT = VecVT.getHalfNumVectorElementsVT();
446-
bool IsHi =
447-
InsertExtractIdx >= VecVT.getVectorElementCount().getKnownMinValue();
448-
SubRegIdx = TRI->composeSubRegIndices(SubRegIdx,
449-
getSubregIndexByMVT(VecVT, IsHi));
450-
if (IsHi)
451-
InsertExtractIdx -= VecVT.getVectorElementCount().getKnownMinValue();
452-
}
453-
return {SubRegIdx, InsertExtractIdx};
454-
}
455360

456361
void RISCVDAGToDAGISel::Select(SDNode *Node) {
457362
// If we have a custom node, we have already selected.
@@ -726,8 +631,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
726631
assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
727632
"Element count mismatch");
728633

729-
RISCVVLMUL LMUL = getLMUL(VT);
730-
RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
634+
RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
635+
RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
731636
unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
732637
const RISCV::VLX_VSXPseudo *P = RISCV::getVLXPseudo(
733638
IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -855,8 +760,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
855760
assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
856761
"Element count mismatch");
857762

858-
RISCVVLMUL LMUL = getLMUL(VT);
859-
RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
763+
RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
764+
RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
860765
unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
861766
const RISCV::VLX_VSXPseudo *P = RISCV::getVSXPseudo(
862767
IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -895,7 +800,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
895800
// For now, keep the two paths separate.
896801
if (VT.isScalableVector() && SubVecVT.isScalableVector()) {
897802
bool IsFullVecReg = false;
898-
switch (getLMUL(SubVecVT)) {
803+
switch (RISCVTargetLowering::getLMUL(SubVecVT)) {
899804
default:
900805
break;
901806
case RISCVVLMUL::LMUL_1:
@@ -915,10 +820,11 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
915820
const auto *TRI = Subtarget->getRegisterInfo();
916821
unsigned SubRegIdx;
917822
std::tie(SubRegIdx, Idx) =
918-
decomposeSubvectorInsertExtractToSubRegs(VT, SubVecVT, Idx, TRI);
823+
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
824+
VT, SubVecVT, Idx, TRI);
919825

920826
// If the Idx hasn't been completely eliminated then this is a subvector
921-
// extract which doesn't naturally align to a vector register. These must
827+
// insert which doesn't naturally align to a vector register. These must
922828
// be handled using instructions to manipulate the vector registers.
923829
if (Idx != 0)
924830
break;
@@ -936,7 +842,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
936842
if (!Node->getOperand(0).isUndef())
937843
break;
938844

939-
unsigned RegClassID = getRegClassIDForVecVT(VT);
845+
unsigned RegClassID = RISCVTargetLowering::getRegClassIDForVecVT(VT);
940846

941847
SDValue RC =
942848
CurDAG->getTargetConstant(RegClassID, DL, Subtarget->getXLenVT());
@@ -961,7 +867,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
961867
const auto *TRI = Subtarget->getRegisterInfo();
962868
unsigned SubRegIdx;
963869
std::tie(SubRegIdx, Idx) =
964-
decomposeSubvectorInsertExtractToSubRegs(InVT, VT, Idx, TRI);
870+
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
871+
InVT, VT, Idx, TRI);
965872

966873
// If the Idx hasn't been completely eliminated then this is a subvector
967874
// extract which doesn't naturally align to a vector register. These must
@@ -972,8 +879,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
972879
// If we haven't set a SubRegIdx, then we must be going between LMUL<=1
973880
// types (VR -> VR). This can be done as a copy.
974881
if (SubRegIdx == RISCV::NoSubRegister) {
975-
unsigned InRegClassID = getRegClassIDForVecVT(InVT);
976-
assert(getRegClassIDForVecVT(VT) == RISCV::VRRegClassID &&
882+
unsigned InRegClassID =
883+
RISCVTargetLowering::getRegClassIDForVecVT(InVT);
884+
assert(RISCVTargetLowering::getRegClassIDForVecVT(VT) ==
885+
RISCV::VRRegClassID &&
977886
InRegClassID == RISCV::VRRegClassID &&
978887
"Unexpected subvector extraction");
979888
SDValue RC =
@@ -993,7 +902,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
993902
if (Idx != 0)
994903
break;
995904

996-
unsigned InRegClassID = getRegClassIDForVecVT(InVT);
905+
unsigned InRegClassID = RISCVTargetLowering::getRegClassIDForVecVT(InVT);
997906

998907
SDValue RC =
999908
CurDAG->getTargetConstant(InRegClassID, DL, Subtarget->getXLenVT());

0 commit comments

Comments
 (0)