Skip to content

Commit 89f119c

Browse files
mikhailramalhopreames
authored andcommitted
[RISCV] Update matchSplatAsGather to use the index of extract_elt if in-bounds (llvm#118873)
This is a follow-up to llvm#117878 and allows the usage of vrgather if the index we are accessing in VT is constant and within bounds. This patch replaces the previous behavior of bailing out if the length of the search vector is greater than the vector of elements we are searching for. Since matchSplatAsGather works on EXTRACT_VECTOR_ELT, and we know the index from which the element is extracted, we only need to check if we are doing an insert from a larger vector into a smaller one, in which we do an extract instead. Co-authored-by: Luke Lau [email protected] Co-authored-by: Philip Reames [email protected]
1 parent cac67d3 commit 89f119c

File tree

3 files changed

+439
-696
lines changed

3 files changed

+439
-696
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3522,27 +3522,43 @@ static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
35223522
// different
35233523
// FIXME: Support i1 vectors, maybe by promoting to i8?
35243524
MVT EltTy = VT.getVectorElementType();
3525-
if (EltTy == MVT::i1 ||
3526-
EltTy != Vec.getSimpleValueType().getVectorElementType())
3525+
MVT SrcVT = Vec.getSimpleValueType();
3526+
if (EltTy == MVT::i1 || EltTy != SrcVT.getVectorElementType())
35273527
return SDValue();
35283528
SDValue Idx = SplatVal.getOperand(1);
35293529
// The index must be a legal type.
35303530
if (Idx.getValueType() != Subtarget.getXLenVT())
35313531
return SDValue();
35323532

3533-
// Check that Index lies within VT
3534-
// TODO: Can we check if the Index is constant and known in-bounds?
3535-
if (!TypeSize::isKnownLE(Vec.getValueSizeInBits(), VT.getSizeInBits()))
3536-
return SDValue();
3533+
// Check that we know Idx lies within VT
3534+
if (!TypeSize::isKnownLE(SrcVT.getSizeInBits(), VT.getSizeInBits())) {
3535+
auto *CIdx = dyn_cast<ConstantSDNode>(Idx);
3536+
if (!CIdx || CIdx->getZExtValue() >= VT.getVectorMinNumElements())
3537+
return SDValue();
3538+
}
35373539

3540+
// Convert fixed length vectors to scalable
35383541
MVT ContainerVT = VT;
35393542
if (VT.isFixedLengthVector())
35403543
ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
35413544

3542-
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
3543-
DAG.getUNDEF(ContainerVT), Vec,
3544-
DAG.getVectorIdxConstant(0, DL));
3545+
MVT SrcContainerVT = SrcVT;
3546+
if (SrcVT.isFixedLengthVector()) {
3547+
SrcContainerVT = getContainerForFixedLengthVector(DAG, SrcVT, Subtarget);
3548+
Vec = convertToScalableVector(SrcContainerVT, Vec, DAG, Subtarget);
3549+
}
3550+
3551+
// Put Vec in a VT sized vector
3552+
if (SrcContainerVT.getVectorMinNumElements() <
3553+
ContainerVT.getVectorMinNumElements())
3554+
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
3555+
DAG.getUNDEF(ContainerVT), Vec,
3556+
DAG.getVectorIdxConstant(0, DL));
3557+
else
3558+
Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
3559+
DAG.getVectorIdxConstant(0, DL));
35453560

3561+
// We checked that Idx fits inside VT earlier
35463562
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
35473563

35483564
SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec,

0 commit comments

Comments
 (0)