Skip to content

Commit b788713

Browse files
committed
Share logic from extract_vector_elt
1 parent e3534d3 commit b788713

File tree

2 files changed

+47
-56
lines changed

2 files changed

+47
-56
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7466,6 +7466,32 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
74667466
return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget);
74677467
}
74687468

7469+
// Given a scalable vector type and an index into it, returns the type for the
7470+
// smallest subvector that the index fits in. This can be used to reduce LMUL
7471+
// for operations like vslidedown.
7472+
//
7473+
// E.g. With Zvl128b, index 3 in a nxv4i32 fits within the first nxv2i32.
7474+
static std::optional<MVT>
7475+
getSmallestVTForIndex(MVT VecVT, unsigned MaxIdx, SDLoc DL, SelectionDAG &DAG,
7476+
const RISCVSubtarget &Subtarget) {
7477+
assert(VecVT.isScalableVector());
7478+
const unsigned EltSize = VecVT.getScalarSizeInBits();
7479+
const unsigned VectorBitsMin = Subtarget.getRealMinVLen();
7480+
const unsigned MinVLMAX = VectorBitsMin / EltSize;
7481+
MVT SmallerVT;
7482+
if (MaxIdx < MinVLMAX)
7483+
SmallerVT = getLMUL1VT(VecVT);
7484+
else if (MaxIdx < MinVLMAX * 2)
7485+
SmallerVT = getLMUL1VT(VecVT).getDoubleNumVectorElementsVT();
7486+
else if (MaxIdx < MinVLMAX * 4)
7487+
SmallerVT = getLMUL1VT(VecVT)
7488+
.getDoubleNumVectorElementsVT()
7489+
.getDoubleNumVectorElementsVT();
7490+
if (!SmallerVT.isValid() || !VecVT.bitsGT(SmallerVT))
7491+
return std::nullopt;
7492+
return SmallerVT;
7493+
}
7494+
74697495
// Custom-lower EXTRACT_VECTOR_ELT operations to slide the vector down, then
74707496
// extract the first element: (extractelt (slidedown vec, idx), 0). For integer
74717497
// types this is done using VMV_X_S to allow us to glean information about the
@@ -7554,21 +7580,9 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
75547580
if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx))
75557581
MaxIdx = IdxC->getZExtValue();
75567582
if (MaxIdx) {
7557-
const unsigned EltSize = ContainerVT.getScalarSizeInBits();
7558-
const unsigned VectorBitsMin = Subtarget.getRealMinVLen();
7559-
const unsigned MinVLMAX = VectorBitsMin/EltSize;
7560-
MVT SmallerVT;
7561-
if (*MaxIdx < MinVLMAX)
7562-
SmallerVT = getLMUL1VT(ContainerVT);
7563-
else if (*MaxIdx < MinVLMAX * 2)
7564-
SmallerVT = getLMUL1VT(ContainerVT)
7565-
.getDoubleNumVectorElementsVT();
7566-
else if (*MaxIdx < MinVLMAX * 4)
7567-
SmallerVT = getLMUL1VT(ContainerVT)
7568-
.getDoubleNumVectorElementsVT()
7569-
.getDoubleNumVectorElementsVT();
7570-
if (SmallerVT.isValid() && ContainerVT.bitsGT(SmallerVT)) {
7571-
ContainerVT = SmallerVT;
7583+
if (auto SmallerVT =
7584+
getSmallestVTForIndex(ContainerVT, *MaxIdx, DL, DAG, Subtarget)) {
7585+
ContainerVT = *SmallerVT;
75727586
Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
75737587
DAG.getConstant(0, DL, XLenVT));
75747588
}
@@ -8752,37 +8766,14 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
87528766
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
87538767
}
87548768

8755-
// The minimum number of elements for a scalable vector type, e.g. nxv1i32
8756-
// is not legal on Zve32x.
8757-
const uint64_t MinLegalNumElts =
8758-
RISCV::RVVBitsPerBlock / Subtarget.getELen();
8759-
const uint64_t MinVscale =
8760-
Subtarget.getRealMinVLen() / RISCV::RVVBitsPerBlock;
8761-
8762-
// Even if we don't know the exact subregister the subvector is going to
8763-
// reside in, we know that the subvector is located within the first N bits
8764-
// of Vec:
8765-
//
8766-
// N = (OrigIdx + SubVecVT.getVectorNumElements()) * EltSizeInBits
8767-
// = MinVscale * MinEltsNeeded * EltSizeInBits
8768-
//
8769-
// From this we can work out the smallest type that contains everything we
8770-
// need to extract, <vscale x MinEltsNeeded x Elt>
8771-
uint64_t MinEltsNeeded =
8772-
(OrigIdx + SubVecVT.getVectorNumElements()) / MinVscale;
8773-
8774-
// Round up the number of elements so it's a valid power of 2 scalable
8775-
// vector type, and make sure it's not less than smallest legal vector type.
8776-
MinEltsNeeded = std::max(MinLegalNumElts, PowerOf2Ceil(MinEltsNeeded));
8777-
8778-
assert(MinEltsNeeded <= ContainerVT.getVectorMinNumElements());
8779-
8780-
// Shrink down Vec so we're performing the slidedown on the smallest
8781-
// possible type.
8782-
ContainerVT = MVT::getScalableVectorVT(ContainerVT.getVectorElementType(),
8783-
MinEltsNeeded);
8784-
Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
8785-
DAG.getVectorIdxConstant(0, DL));
8769+
// Shrink down Vec so we're performing the slidedown a smaller LMUL.
8770+
unsigned LastIdx = OrigIdx + SubVecVT.getVectorNumElements() - 1;
8771+
if (auto ShrunkVT =
8772+
getSmallestVTForIndex(ContainerVT, LastIdx, DL, DAG, Subtarget)) {
8773+
ContainerVT = *ShrunkVT;
8774+
Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
8775+
DAG.getVectorIdxConstant(0, DL));
8776+
}
87868777

87878778
SDValue Mask =
87888779
getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-extract-subvector.ll

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ define void @extract_v2i32_v8i32_2(ptr %x, ptr %y) {
113113
; CHECK: # %bb.0:
114114
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
115115
; CHECK-NEXT: vle32.v v8, (a0)
116-
; CHECK-NEXT: vsetivli zero, 2, e32, m2, ta, ma
116+
; CHECK-NEXT: vsetivli zero, 2, e32, m1, ta, ma
117117
; CHECK-NEXT: vslidedown.vi v8, v8, 2
118118
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
119119
; CHECK-NEXT: vse32.v v8, (a1)
@@ -171,7 +171,7 @@ define void @extract_v2i32_nxv16i32_0(<vscale x 16 x i32> %x, ptr %y) {
171171
define void @extract_v2i32_nxv16i32_2(<vscale x 16 x i32> %x, ptr %y) {
172172
; CHECK-LABEL: extract_v2i32_nxv16i32_2:
173173
; CHECK: # %bb.0:
174-
; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
174+
; CHECK-NEXT: vsetivli zero, 2, e32, m1, ta, ma
175175
; CHECK-NEXT: vslidedown.vi v8, v8, 2
176176
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
177177
; CHECK-NEXT: vse32.v v8, (a0)
@@ -184,7 +184,7 @@ define void @extract_v2i32_nxv16i32_2(<vscale x 16 x i32> %x, ptr %y) {
184184
define void @extract_v2i32_nxv16i32_4(<vscale x 16 x i32> %x, ptr %y) {
185185
; CHECK-LABEL: extract_v2i32_nxv16i32_4:
186186
; CHECK: # %bb.0:
187-
; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
187+
; CHECK-NEXT: vsetivli zero, 2, e32, m2, ta, ma
188188
; CHECK-NEXT: vslidedown.vi v8, v8, 4
189189
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
190190
; CHECK-NEXT: vse32.v v8, (a0)
@@ -197,7 +197,7 @@ define void @extract_v2i32_nxv16i32_4(<vscale x 16 x i32> %x, ptr %y) {
197197
define void @extract_v2i32_nxv16i32_6(<vscale x 16 x i32> %x, ptr %y) {
198198
; CHECK-LABEL: extract_v2i32_nxv16i32_6:
199199
; CHECK: # %bb.0:
200-
; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
200+
; CHECK-NEXT: vsetivli zero, 2, e32, m2, ta, ma
201201
; CHECK-NEXT: vslidedown.vi v8, v8, 6
202202
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
203203
; CHECK-NEXT: vse32.v v8, (a0)
@@ -210,7 +210,7 @@ define void @extract_v2i32_nxv16i32_6(<vscale x 16 x i32> %x, ptr %y) {
210210
define void @extract_v2i32_nxv16i32_8(<vscale x 16 x i32> %x, ptr %y) {
211211
; CHECK-LABEL: extract_v2i32_nxv16i32_8:
212212
; CHECK: # %bb.0:
213-
; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
213+
; CHECK-NEXT: vsetivli zero, 2, e32, m4, ta, ma
214214
; CHECK-NEXT: vslidedown.vi v8, v8, 8
215215
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
216216
; CHECK-NEXT: vse32.v v8, (a0)
@@ -273,7 +273,7 @@ define void @extract_v2i8_nxv2i8_6(<vscale x 2 x i8> %x, ptr %y) {
273273
define void @extract_v8i32_nxv16i32_8(<vscale x 16 x i32> %x, ptr %y) {
274274
; CHECK-LABEL: extract_v8i32_nxv16i32_8:
275275
; CHECK: # %bb.0:
276-
; CHECK-NEXT: vsetivli zero, 8, e32, m8, ta, ma
276+
; CHECK-NEXT: vsetivli zero, 8, e32, m4, ta, ma
277277
; CHECK-NEXT: vslidedown.vi v8, v8, 8
278278
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
279279
; CHECK-NEXT: vse32.v v8, (a0)
@@ -437,7 +437,7 @@ define void @extract_v2i1_v64i1_2(ptr %x, ptr %y) {
437437
; CHECK-NEXT: vlm.v v0, (a0)
438438
; CHECK-NEXT: vmv.v.i v8, 0
439439
; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
440-
; CHECK-NEXT: vsetivli zero, 2, e8, m4, ta, ma
440+
; CHECK-NEXT: vsetivli zero, 2, e8, m1, ta, ma
441441
; CHECK-NEXT: vslidedown.vi v8, v8, 2
442442
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
443443
; CHECK-NEXT: vmsne.vi v0, v8, 0
@@ -555,7 +555,7 @@ define void @extract_v2i1_nxv64i1_2(<vscale x 64 x i1> %x, ptr %y) {
555555
; CHECK-NEXT: vsetvli a1, zero, e8, m8, ta, ma
556556
; CHECK-NEXT: vmv.v.i v8, 0
557557
; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
558-
; CHECK-NEXT: vsetivli zero, 2, e8, m8, ta, ma
558+
; CHECK-NEXT: vsetivli zero, 2, e8, m1, ta, ma
559559
; CHECK-NEXT: vslidedown.vi v8, v8, 2
560560
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
561561
; CHECK-NEXT: vmsne.vi v0, v8, 0
@@ -581,7 +581,7 @@ define void @extract_v2i1_nxv64i1_42(<vscale x 64 x i1> %x, ptr %y) {
581581
; CHECK-NEXT: vmv.v.i v8, 0
582582
; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
583583
; CHECK-NEXT: li a1, 42
584-
; CHECK-NEXT: vsetivli zero, 2, e8, m8, ta, ma
584+
; CHECK-NEXT: vsetivli zero, 2, e8, m4, ta, ma
585585
; CHECK-NEXT: vslidedown.vx v8, v8, a1
586586
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
587587
; CHECK-NEXT: vmsne.vi v0, v8, 0
@@ -606,7 +606,7 @@ define void @extract_v2i1_nxv32i1_26(<vscale x 32 x i1> %x, ptr %y) {
606606
; CHECK-NEXT: vsetvli a1, zero, e8, m4, ta, ma
607607
; CHECK-NEXT: vmv.v.i v8, 0
608608
; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
609-
; CHECK-NEXT: vsetivli zero, 2, e8, m4, ta, ma
609+
; CHECK-NEXT: vsetivli zero, 2, e8, m2, ta, ma
610610
; CHECK-NEXT: vslidedown.vi v8, v8, 26
611611
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
612612
; CHECK-NEXT: vmsne.vi v0, v8, 0

0 commit comments

Comments
 (0)