Skip to content

Commit 81cf472

Browse files
committed
[RISCV] Support llvm.masked.expandload intrinsic
We can use `viota.m` + indexed load to synthesize an expanding load:

```
%res = llvm.masked.expandload(%ptr, %mask, %passthru)
->
%index = viota %mask
if elt_size > 8:
  %index = vsll.vi %index, log2(elt_size), %mask
%res = vluxei<n> %passthru, %ptr, %index, %mask
```

And if `%mask` is all ones, we can lower the expanding load to a normal unmasked load.

Fixes llvm#101914
1 parent fe85566 commit 81cf472

File tree

6 files changed

+1768
-1786
lines changed

6 files changed

+1768
-1786
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10733,6 +10733,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1073310733
SDValue BasePtr = MemSD->getBasePtr();
1073410734

1073510735
SDValue Mask, PassThru, VL;
10736+
bool IsExpandingLoad = false;
1073610737
if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
1073710738
Mask = VPLoad->getMask();
1073810739
PassThru = DAG.getUNDEF(VT);
@@ -10741,6 +10742,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1074110742
const auto *MLoad = cast<MaskedLoadSDNode>(Op);
1074210743
Mask = MLoad->getMask();
1074310744
PassThru = MLoad->getPassThru();
10745+
IsExpandingLoad = MLoad->isExpandingLoad();
1074410746
}
1074510747

1074610748
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
@@ -10760,16 +10762,38 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1076010762
if (!VL)
1076110763
VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
1076210764

10763-
unsigned IntID =
10764-
IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
10765+
SDValue Index;
10766+
if (!IsUnmasked && IsExpandingLoad) {
10767+
MVT IndexVT = ContainerVT;
10768+
if (ContainerVT.isFloatingPoint())
10769+
IndexVT = IndexVT.changeVectorElementTypeToInteger();
10770+
10771+
if (Subtarget.isRV32() && IndexVT.getVectorElementType().bitsGT(XLenVT))
10772+
IndexVT = IndexVT.changeVectorElementType(XLenVT);
10773+
10774+
Index = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
10775+
DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
10776+
DAG.getUNDEF(IndexVT), Mask, VL);
10777+
if (uint64_t EltSize = ContainerVT.getScalarSizeInBits(); EltSize > 8)
10778+
Index = DAG.getNode(RISCVISD::SHL_VL, DL, IndexVT, Index,
10779+
DAG.getConstant(Log2_64(EltSize / 8), DL, IndexVT),
10780+
DAG.getUNDEF(IndexVT), Mask, VL);
10781+
}
10782+
10783+
unsigned IntID = IsUnmasked ? Intrinsic::riscv_vle
10784+
: IsExpandingLoad ? Intrinsic::riscv_vluxei_mask
10785+
: Intrinsic::riscv_vle_mask;
1076510786
SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
1076610787
if (IsUnmasked)
1076710788
Ops.push_back(DAG.getUNDEF(ContainerVT));
1076810789
else
1076910790
Ops.push_back(PassThru);
1077010791
Ops.push_back(BasePtr);
10771-
if (!IsUnmasked)
10792+
if (!IsUnmasked) {
10793+
if (IsExpandingLoad)
10794+
Ops.push_back(Index);
1077210795
Ops.push_back(Mask);
10796+
}
1077310797
Ops.push_back(VL);
1077410798
if (!IsUnmasked)
1077510799
Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,6 +1966,16 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
19661966
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
19671967
}
19681968

1969+
/// Reports whether a masked expanding load of \p DataTy with the given
/// \p Alignment is considered legal.
///
/// The answer is false when \p DataTy is not a vector type or is a scalable
/// type; otherwise it is whatever isLegalMaskedLoadStore() reports for the
/// same type and alignment.
bool RISCVTTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
  auto *VecTy = dyn_cast<VectorType>(DataTy);
  // Short-circuiting keeps the null check ahead of the scalability query and
  // only consults the masked load/store legality for accepted vector types.
  const bool IsNonScalableVector = VecTy && !VecTy->isScalableTy();
  return IsNonScalableVector && isLegalMaskedLoadStore(DataTy, Alignment);
}
1978+
19691979
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
19701980
auto *VTy = dyn_cast<VectorType>(DataTy);
19711981
if (!VTy || VTy->isScalableTy())

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
281281
return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
282282
}
283283

284+
bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);
285+
284286
bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
285287

286288
bool isVScaleKnownToBeAPowerOfTwo() const {

0 commit comments

Comments
 (0)