Skip to content

Commit aa47bfa

Browse files
committed
[RISCV] Refactor getDefaultVLOps. NFC.
Current getDefaultVLOps can only deduce VL from a MVT. However, sometimes users have already known VL value. This commit will provide a uniform interface to get VL instead of calling DAG.getConstant. Differential Revision: https://reviews.llvm.org/D138003
1 parent 5d19fea commit aa47bfa

File tree

1 file changed

+30
-20
lines changed

1 file changed

+30
-20
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,17 +1847,32 @@ static SDValue getAllOnesMask(MVT VecVT, SDValue VL, SDLoc DL,
18471847
return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
18481848
}
18491849

1850+
static SDValue getVLOp(uint64_t NumElts, SDLoc DL, SelectionDAG &DAG,
1851+
const RISCVSubtarget &Subtarget) {
1852+
return DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
1853+
}
1854+
1855+
static std::pair<SDValue, SDValue>
1856+
getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, SDLoc DL, SelectionDAG &DAG,
1857+
const RISCVSubtarget &Subtarget) {
1858+
assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
1859+
SDValue VL = getVLOp(NumElts, DL, DAG, Subtarget);
1860+
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
1861+
return {Mask, VL};
1862+
}
1863+
18501864
// Gets the two common "VL" operands: an all-ones mask and the vector length.
18511865
// VecVT is a vector type, either fixed-length or scalable, and ContainerVT is
18521866
// the vector type that it is contained in.
18531867
static std::pair<SDValue, SDValue>
18541868
getDefaultVLOps(MVT VecVT, MVT ContainerVT, SDLoc DL, SelectionDAG &DAG,
18551869
const RISCVSubtarget &Subtarget) {
1870+
if (VecVT.isFixedLengthVector())
1871+
return getDefaultVLOps(VecVT.getVectorNumElements(), ContainerVT, DL, DAG,
1872+
Subtarget);
18561873
assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
18571874
MVT XLenVT = Subtarget.getXLenVT();
1858-
SDValue VL = VecVT.isFixedLengthVector()
1859-
? DAG.getConstant(VecVT.getVectorNumElements(), DL, XLenVT)
1860-
: DAG.getRegister(RISCV::X0, XLenVT);
1875+
SDValue VL = DAG.getRegister(RISCV::X0, XLenVT);
18611876
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
18621877
return {Mask, VL};
18631878
}
@@ -5115,8 +5130,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
51155130
// If the index is 0, the vector is already in the right position.
51165131
if (!isNullConstant(Idx)) {
51175132
// Use a VL of 1 to avoid processing more elements than we need.
5118-
SDValue VL = DAG.getConstant(1, DL, XLenVT);
5119-
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
5133+
auto [Mask, VL] = getDefaultVLOps(1, ContainerVT, DL, DAG, Subtarget);
51205134
Vec = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, ContainerVT,
51215135
DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
51225136
}
@@ -5486,7 +5500,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
54865500
MVT VT = Op->getSimpleValueType(0);
54875501
MVT ContainerVT = getContainerForFixedLengthVector(VT);
54885502

5489-
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
5503+
SDValue VL = getVLOp(VT.getVectorNumElements(), DL, DAG, Subtarget);
54905504
SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
54915505
auto *Load = cast<MemIntrinsicSDNode>(Op);
54925506
SmallVector<EVT, 9> ContainerVTs(NF, ContainerVT);
@@ -5932,7 +5946,7 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
59325946
// Set the vector length to only the number of elements we care about. Note
59335947
// that for slideup this includes the offset.
59345948
SDValue VL =
5935-
DAG.getConstant(OrigIdx + SubVecVT.getVectorNumElements(), DL, XLenVT);
5949+
getVLOp(OrigIdx + SubVecVT.getVectorNumElements(), DL, DAG, Subtarget);
59365950
SDValue SlideupAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
59375951
SDValue Slideup = DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, ContainerVT, Vec,
59385952
SubVec, SlideupAmt, Mask, VL);
@@ -6078,7 +6092,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
60786092
getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
60796093
// Set the vector length to only the number of elements we care about. This
60806094
// avoids sliding down elements we're going to discard straight away.
6081-
SDValue VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
6095+
SDValue VL = getVLOp(SubVecVT.getVectorNumElements(), DL, DAG, Subtarget);
60826096
SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
60836097
SDValue Slidedown =
60846098
DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, ContainerVT,
@@ -6220,7 +6234,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
62206234
// Calculate VLMAX-1 for the desired SEW.
62216235
unsigned MinElts = VecVT.getVectorMinNumElements();
62226236
SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT,
6223-
DAG.getConstant(MinElts, DL, XLenVT));
6237+
getVLOp(MinElts, DL, DAG, Subtarget));
62246238
SDValue VLMinus1 =
62256239
DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, DAG.getConstant(1, DL, XLenVT));
62266240

@@ -6252,7 +6266,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_SPLICE(SDValue Op,
62526266

62536267
unsigned MinElts = VecVT.getVectorMinNumElements();
62546268
SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT,
6255-
DAG.getConstant(MinElts, DL, XLenVT));
6269+
getVLOp(MinElts, DL, DAG, Subtarget));
62566270

62576271
int64_t ImmValue = cast<ConstantSDNode>(Op.getOperand(2))->getSExtValue();
62586272
SDValue DownOffset, UpOffset;
@@ -6292,7 +6306,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
62926306
MVT XLenVT = Subtarget.getXLenVT();
62936307
MVT ContainerVT = getContainerForFixedLengthVector(VT);
62946308

6295-
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
6309+
SDValue VL = getVLOp(VT.getVectorNumElements(), DL, DAG, Subtarget);
62966310

62976311
bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
62986312
SDValue IntID = DAG.getTargetConstant(
@@ -6336,7 +6350,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
63366350

63376351
MVT ContainerVT = getContainerForFixedLengthVector(VT);
63386352

6339-
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
6353+
SDValue VL = getVLOp(VT.getVectorNumElements(), DL, DAG, Subtarget);
63406354

63416355
SDValue NewValue =
63426356
convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
@@ -6482,11 +6496,9 @@ RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
64826496
convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
64836497

64846498
SDLoc DL(Op);
6485-
SDValue VL =
6486-
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
6487-
6499+
auto [Mask, VL] = getDefaultVLOps(VT.getVectorNumElements(), ContainerVT, DL,
6500+
DAG, Subtarget);
64886501
MVT MaskVT = getMaskTypeFor(ContainerVT);
6489-
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
64906502

64916503
SDValue Cmp =
64926504
DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
@@ -7720,8 +7732,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
77207732
MVT XLenVT = Subtarget.getXLenVT();
77217733

77227734
// Use a VL of 1 to avoid processing more elements than we need.
7723-
SDValue VL = DAG.getConstant(1, DL, XLenVT);
7724-
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
7735+
auto [Mask, VL] = getDefaultVLOps(1, ContainerVT, DL, DAG, Subtarget);
77257736

77267737
// Unless the index is known to be 0, we must slide the vector down to get
77277738
// the desired element into index 0.
@@ -7783,8 +7794,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
77837794

77847795
// To extract the upper XLEN bits of the vector element, shift the first
77857796
// element right by 32 bits and re-extract the lower XLEN bits.
7786-
SDValue VL = DAG.getConstant(1, DL, XLenVT);
7787-
SDValue Mask = getAllOnesMask(VecVT, VL, DL, DAG);
7797+
auto [Mask, VL] = getDefaultVLOps(1, VecVT, DL, DAG, Subtarget);
77887798

77897799
SDValue ThirtyTwoV =
77907800
DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT),

0 commit comments

Comments
 (0)