@@ -10732,33 +10732,44 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
10732
10732
SDValue Chain = MemSD->getChain();
10733
10733
SDValue BasePtr = MemSD->getBasePtr();
10734
10734
10735
- SDValue Mask, PassThru, VL;
10735
+ SDValue Mask, PassThru, LoadVL;
10736
+ bool IsExpandingLoad = false;
10736
10737
if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
10737
10738
Mask = VPLoad->getMask();
10738
10739
PassThru = DAG.getUNDEF(VT);
10739
- VL = VPLoad->getVectorLength();
10740
+ LoadVL = VPLoad->getVectorLength();
10740
10741
} else {
10741
10742
const auto *MLoad = cast<MaskedLoadSDNode>(Op);
10742
10743
Mask = MLoad->getMask();
10743
10744
PassThru = MLoad->getPassThru();
10745
+ IsExpandingLoad = MLoad->isExpandingLoad();
10744
10746
}
10745
10747
10746
- bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
10748
+ bool IsUnmasked =
10749
+ ISD::isConstantSplatVectorAllOnes(Mask.getNode()) || IsExpandingLoad;
10747
10750
10748
10751
MVT XLenVT = Subtarget.getXLenVT();
10749
10752
10750
10753
MVT ContainerVT = VT;
10751
10754
if (VT.isFixedLengthVector()) {
10752
10755
ContainerVT = getContainerForFixedLengthVector(VT);
10753
10756
PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
10754
- if (!IsUnmasked) {
10757
+ if (!IsUnmasked || IsExpandingLoad ) {
10755
10758
MVT MaskVT = getMaskTypeFor(ContainerVT);
10756
10759
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
10757
10760
}
10758
10761
}
10759
10762
10760
- if (!VL)
10761
- VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10763
+ if (!LoadVL)
10764
+ LoadVL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10765
+
10766
+ SDValue ExpandingVL;
10767
+ if (IsExpandingLoad) {
10768
+ ExpandingVL = LoadVL;
10769
+ LoadVL = DAG.getNode(
10770
+ RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
10771
+ getAllOnesMask(Mask.getSimpleValueType(), LoadVL, DL, DAG), LoadVL);
10772
+ }
10762
10773
10763
10774
unsigned IntID =
10764
10775
IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
@@ -10770,7 +10781,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
10770
10781
Ops.push_back(BasePtr);
10771
10782
if (!IsUnmasked)
10772
10783
Ops.push_back(Mask);
10773
- Ops.push_back(VL );
10784
+ Ops.push_back(LoadVL );
10774
10785
if (!IsUnmasked)
10775
10786
Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
10776
10787
@@ -10779,6 +10790,18 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
10779
10790
SDValue Result =
10780
10791
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
10781
10792
Chain = Result.getValue(1);
10793
+ if (IsExpandingLoad) {
10794
+ MVT IotaVT = ContainerVT;
10795
+ if (ContainerVT.isFloatingPoint())
10796
+ IotaVT = ContainerVT.changeVectorElementTypeToInteger();
10797
+
10798
+ SDValue Iota =
10799
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IotaVT,
10800
+ DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
10801
+ DAG.getUNDEF(IotaVT), Mask, ExpandingVL);
10802
+ Result = DAG.getNode(RISCVISD::VRGATHER_VV_VL, DL, ContainerVT, Result,
10803
+ Iota, PassThru, Mask, ExpandingVL);
10804
+ }
10782
10805
10783
10806
if (VT.isFixedLengthVector())
10784
10807
Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
0 commit comments