@@ -11122,6 +11122,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
11122
11122
SDValue BasePtr = MemSD->getBasePtr();
11123
11123
11124
11124
SDValue Mask, PassThru, VL;
11125
+ bool IsExpandingLoad = false;
11125
11126
if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
11126
11127
Mask = VPLoad->getMask();
11127
11128
PassThru = DAG.getUNDEF(VT);
@@ -11130,6 +11131,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
11130
11131
const auto *MLoad = cast<MaskedLoadSDNode>(Op);
11131
11132
Mask = MLoad->getMask();
11132
11133
PassThru = MLoad->getPassThru();
11134
+ IsExpandingLoad = MLoad->isExpandingLoad();
11133
11135
}
11134
11136
11135
11137
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
@@ -11149,25 +11151,59 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
11149
11151
if (!VL)
11150
11152
VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
11151
11153
11152
- unsigned IntID =
11153
- IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
11154
+ SDValue ExpandingVL;
11155
+ if (!IsUnmasked && IsExpandingLoad) {
11156
+ ExpandingVL = VL;
11157
+ VL =
11158
+ DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
11159
+ getAllOnesMask(Mask.getSimpleValueType(), VL, DL, DAG), VL);
11160
+ }
11161
+
11162
+ unsigned IntID = IsUnmasked || IsExpandingLoad ? Intrinsic::riscv_vle
11163
+ : Intrinsic::riscv_vle_mask;
11154
11164
SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
11155
- if (IsUnmasked )
11165
+ if (IntID == Intrinsic::riscv_vle )
11156
11166
Ops.push_back(DAG.getUNDEF(ContainerVT));
11157
11167
else
11158
11168
Ops.push_back(PassThru);
11159
11169
Ops.push_back(BasePtr);
11160
- if (!IsUnmasked )
11170
+ if (IntID == Intrinsic::riscv_vle_mask )
11161
11171
Ops.push_back(Mask);
11162
11172
Ops.push_back(VL);
11163
- if (!IsUnmasked )
11173
+ if (IntID == Intrinsic::riscv_vle_mask )
11164
11174
Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
11165
11175
11166
11176
SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
11167
11177
11168
11178
SDValue Result =
11169
11179
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
11170
11180
Chain = Result.getValue(1);
11181
+ if (ExpandingVL) {
11182
+ MVT IndexVT = ContainerVT;
11183
+ if (ContainerVT.isFloatingPoint())
11184
+ IndexVT = ContainerVT.changeVectorElementTypeToInteger();
11185
+
11186
+ MVT IndexEltVT = IndexVT.getVectorElementType();
11187
+ bool UseVRGATHEREI16 = false;
11188
+ // If index vector is an i8 vector and the element count exceeds 256, we
11189
+ // should change the element type of index vector to i16 to avoid
11190
+ // overflow.
11191
+ if (IndexEltVT == MVT::i8 && VT.getVectorNumElements() > 256) {
11192
+ // FIXME: We need to do vector splitting manually for LMUL=8 cases.
11193
+ assert(getLMUL(IndexVT) != RISCVII::LMUL_8);
11194
+ IndexVT = IndexVT.changeVectorElementType(MVT::i16);
11195
+ UseVRGATHEREI16 = true;
11196
+ }
11197
+
11198
+ SDValue Iota =
11199
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
11200
+ DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
11201
+ DAG.getUNDEF(IndexVT), Mask, ExpandingVL);
11202
+ Result =
11203
+ DAG.getNode(UseVRGATHEREI16 ? RISCVISD::VRGATHEREI16_VV_VL
11204
+ : RISCVISD::VRGATHER_VV_VL,
11205
+ DL, ContainerVT, Result, Iota, PassThru, Mask, ExpandingVL);
11206
+ }
11171
11207
11172
11208
if (VT.isFixedLengthVector())
11173
11209
Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
0 commit comments