Skip to content

Commit 9e4b22d

Browse files
committed
[RISCV] Support llvm.masked.expandload intrinsic
We can use `viota.m` + indexed load to synthesize expanding load: ``` %res = llvm.masked.expandload(%ptr, %mask, %passthru) -> %index = viota %mask if elt_size > 8: %index = vsll.vi %index, log2(elt_size), %mask %res = vluxei<n> %passthru, %ptr, %index, %mask ``` And if `%mask` is all ones, we can lower expanding load to a normal unmasked load. Fixes llvm#101914
1 parent f1ade1f commit 9e4b22d

File tree

6 files changed

+1763
-1771
lines changed

6 files changed

+1763
-1771
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11107,6 +11107,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1110711107
SDValue BasePtr = MemSD->getBasePtr();
1110811108

1110911109
SDValue Mask, PassThru, VL;
11110+
bool IsExpandingLoad = false;
1111011111
if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
1111111112
Mask = VPLoad->getMask();
1111211113
PassThru = DAG.getUNDEF(VT);
@@ -11115,6 +11116,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1111511116
const auto *MLoad = cast<MaskedLoadSDNode>(Op);
1111611117
Mask = MLoad->getMask();
1111711118
PassThru = MLoad->getPassThru();
11119+
IsExpandingLoad = MLoad->isExpandingLoad();
1111811120
}
1111911121

1112011122
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
@@ -11134,16 +11136,38 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1113411136
if (!VL)
1113511137
VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
1113611138

11137-
unsigned IntID =
11138-
IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
11139+
SDValue Index;
11140+
if (!IsUnmasked && IsExpandingLoad) {
11141+
MVT IndexVT = ContainerVT;
11142+
if (ContainerVT.isFloatingPoint())
11143+
IndexVT = IndexVT.changeVectorElementTypeToInteger();
11144+
11145+
if (Subtarget.isRV32() && IndexVT.getVectorElementType().bitsGT(XLenVT))
11146+
IndexVT = IndexVT.changeVectorElementType(XLenVT);
11147+
11148+
Index = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
11149+
DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
11150+
DAG.getUNDEF(IndexVT), Mask, VL);
11151+
if (uint64_t EltSize = ContainerVT.getScalarSizeInBits(); EltSize > 8)
11152+
Index = DAG.getNode(RISCVISD::SHL_VL, DL, IndexVT, Index,
11153+
DAG.getConstant(Log2_64(EltSize / 8), DL, IndexVT),
11154+
DAG.getUNDEF(IndexVT), Mask, VL);
11155+
}
11156+
11157+
unsigned IntID = IsUnmasked ? Intrinsic::riscv_vle
11158+
: IsExpandingLoad ? Intrinsic::riscv_vluxei_mask
11159+
: Intrinsic::riscv_vle_mask;
1113911160
SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
1114011161
if (IsUnmasked)
1114111162
Ops.push_back(DAG.getUNDEF(ContainerVT));
1114211163
else
1114311164
Ops.push_back(PassThru);
1114411165
Ops.push_back(BasePtr);
11145-
if (!IsUnmasked)
11166+
if (!IsUnmasked) {
11167+
if (IsExpandingLoad)
11168+
Ops.push_back(Index);
1114611169
Ops.push_back(Mask);
11170+
}
1114711171
Ops.push_back(VL);
1114811172
if (!IsUnmasked)
1114911173
Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2286,6 +2286,16 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
22862286
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
22872287
}
22882288

2289+
bool RISCVTTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
2290+
auto *VTy = dyn_cast<VectorType>(DataTy);
2291+
if (!VTy || VTy->isScalableTy())
2292+
return false;
2293+
2294+
if (!isLegalMaskedLoadStore(DataTy, Alignment))
2295+
return false;
2296+
return true;
2297+
}
2298+
22892299
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
22902300
auto *VTy = dyn_cast<VectorType>(DataTy);
22912301
if (!VTy || VTy->isScalableTy())

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
301301
DL);
302302
}
303303

304+
bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);
305+
304306
bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
305307

306308
bool isVScaleKnownToBeAPowerOfTwo() const {

0 commit comments

Comments
 (0)