Skip to content

Commit aa662d7

Browse files
committed
[RISCV] Support llvm.masked.expandload intrinsic
We can use `iota+vrgather` to synthesize `vdecompress` and lower expanding load to `vcpop+load+vdecompress`. Fixes #101914
1 parent fe85566 commit aa662d7

File tree

6 files changed

+1874
-1846
lines changed

6 files changed

+1874
-1846
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10732,33 +10732,44 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1073210732
SDValue Chain = MemSD->getChain();
1073310733
SDValue BasePtr = MemSD->getBasePtr();
1073410734

10735-
SDValue Mask, PassThru, VL;
10735+
SDValue Mask, PassThru, LoadVL;
10736+
bool IsExpandingLoad = false;
1073610737
if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
1073710738
Mask = VPLoad->getMask();
1073810739
PassThru = DAG.getUNDEF(VT);
10739-
VL = VPLoad->getVectorLength();
10740+
LoadVL = VPLoad->getVectorLength();
1074010741
} else {
1074110742
const auto *MLoad = cast<MaskedLoadSDNode>(Op);
1074210743
Mask = MLoad->getMask();
1074310744
PassThru = MLoad->getPassThru();
10745+
IsExpandingLoad = MLoad->isExpandingLoad();
1074410746
}
1074510747

10746-
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
10748+
bool IsUnmasked =
10749+
ISD::isConstantSplatVectorAllOnes(Mask.getNode()) || IsExpandingLoad;
1074710750

1074810751
MVT XLenVT = Subtarget.getXLenVT();
1074910752

1075010753
MVT ContainerVT = VT;
1075110754
if (VT.isFixedLengthVector()) {
1075210755
ContainerVT = getContainerForFixedLengthVector(VT);
1075310756
PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
10754-
if (!IsUnmasked) {
10757+
if (!IsUnmasked || IsExpandingLoad) {
1075510758
MVT MaskVT = getMaskTypeFor(ContainerVT);
1075610759
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
1075710760
}
1075810761
}
1075910762

10760-
if (!VL)
10761-
VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10763+
if (!LoadVL)
10764+
LoadVL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10765+
10766+
SDValue ExpandingVL;
10767+
if (IsExpandingLoad) {
10768+
ExpandingVL = LoadVL;
10769+
LoadVL = DAG.getNode(
10770+
RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
10771+
getAllOnesMask(Mask.getSimpleValueType(), LoadVL, DL, DAG), LoadVL);
10772+
}
1076210773

1076310774
unsigned IntID =
1076410775
IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
@@ -10770,7 +10781,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1077010781
Ops.push_back(BasePtr);
1077110782
if (!IsUnmasked)
1077210783
Ops.push_back(Mask);
10773-
Ops.push_back(VL);
10784+
Ops.push_back(LoadVL);
1077410785
if (!IsUnmasked)
1077510786
Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
1077610787

@@ -10779,6 +10790,18 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1077910790
SDValue Result =
1078010791
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
1078110792
Chain = Result.getValue(1);
10793+
if (IsExpandingLoad) {
10794+
MVT IotaVT = ContainerVT;
10795+
if (ContainerVT.isFloatingPoint())
10796+
IotaVT = ContainerVT.changeVectorElementTypeToInteger();
10797+
10798+
SDValue Iota =
10799+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IotaVT,
10800+
DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
10801+
DAG.getUNDEF(IotaVT), Mask, ExpandingVL);
10802+
Result = DAG.getNode(RISCVISD::VRGATHER_VV_VL, DL, ContainerVT, Result,
10803+
Iota, PassThru, Mask, ExpandingVL);
10804+
}
1078210805

1078310806
if (VT.isFixedLengthVector())
1078410807
Result = convertFromScalableVector(VT, Result, DAG, Subtarget);

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,6 +1966,16 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
19661966
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
19671967
}
19681968

1969+
bool RISCVTTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
1970+
auto *VTy = dyn_cast<VectorType>(DataTy);
1971+
if (!VTy || VTy->isScalableTy())
1972+
return false;
1973+
1974+
if (!isLegalMaskedLoadStore(DataTy, Alignment))
1975+
return false;
1976+
return true;
1977+
}
1978+
19691979
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
19701980
auto *VTy = dyn_cast<VectorType>(DataTy);
19711981
if (!VTy || VTy->isScalableTy())

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
281281
return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
282282
}
283283

284+
bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);
285+
284286
bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
285287

286288
bool isVScaleKnownToBeAPowerOfTwo() const {

0 commit comments

Comments
 (0)