@@ -11107,6 +11107,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
11107
11107
SDValue BasePtr = MemSD->getBasePtr();
11108
11108
11109
11109
SDValue Mask, PassThru, VL;
11110
+ bool IsExpandingLoad = false;
11110
11111
if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
11111
11112
Mask = VPLoad->getMask();
11112
11113
PassThru = DAG.getUNDEF(VT);
@@ -11115,6 +11116,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
11115
11116
const auto *MLoad = cast<MaskedLoadSDNode>(Op);
11116
11117
Mask = MLoad->getMask();
11117
11118
PassThru = MLoad->getPassThru();
11119
+ IsExpandingLoad = MLoad->isExpandingLoad();
11118
11120
}
11119
11121
11120
11122
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
@@ -11134,16 +11136,38 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
11134
11136
if (!VL)
11135
11137
VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
11136
11138
11137
- unsigned IntID =
11138
- IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
11139
+ SDValue Index;
11140
+ if (!IsUnmasked && IsExpandingLoad) {
11141
+ MVT IndexVT = ContainerVT;
11142
+ if (ContainerVT.isFloatingPoint())
11143
+ IndexVT = IndexVT.changeVectorElementTypeToInteger();
11144
+
11145
+ if (Subtarget.isRV32() && IndexVT.getVectorElementType().bitsGT(XLenVT))
11146
+ IndexVT = IndexVT.changeVectorElementType(XLenVT);
11147
+
11148
+ Index = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
11149
+ DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
11150
+ DAG.getUNDEF(IndexVT), Mask, VL);
11151
+ if (uint64_t EltSize = ContainerVT.getScalarSizeInBits(); EltSize > 8)
11152
+ Index = DAG.getNode(RISCVISD::SHL_VL, DL, IndexVT, Index,
11153
+ DAG.getConstant(Log2_64(EltSize / 8), DL, IndexVT),
11154
+ DAG.getUNDEF(IndexVT), Mask, VL);
11155
+ }
11156
+
11157
+ unsigned IntID = IsUnmasked ? Intrinsic::riscv_vle
11158
+ : IsExpandingLoad ? Intrinsic::riscv_vluxei_mask
11159
+ : Intrinsic::riscv_vle_mask;
11139
11160
SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
11140
11161
if (IsUnmasked)
11141
11162
Ops.push_back(DAG.getUNDEF(ContainerVT));
11142
11163
else
11143
11164
Ops.push_back(PassThru);
11144
11165
Ops.push_back(BasePtr);
11145
- if (!IsUnmasked)
11166
+ if (!IsUnmasked) {
11167
+ if (IsExpandingLoad)
11168
+ Ops.push_back(Index);
11146
11169
Ops.push_back(Mask);
11170
+ }
11147
11171
Ops.push_back(VL);
11148
11172
if (!IsUnmasked)
11149
11173
Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
0 commit comments