Skip to content

Commit f14ae55

Browse files
committed
[RISCV] Support llvm.masked.expandload intrinsic
We can use `viota`+`vrgather` to synthesize `vdecompress` and lower an expanding load to `vcpop`+`load`+`vdecompress`. If `%mask` is all ones, we can instead lower the expanding load to a normal unmasked load. Fixes #101914.
1 parent fba9f05 commit f14ae55

File tree

6 files changed

+20483
-1932
lines changed

6 files changed

+20483
-1932
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11122,6 +11122,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1112211122
SDValue BasePtr = MemSD->getBasePtr();
1112311123

1112411124
SDValue Mask, PassThru, VL;
11125+
bool IsExpandingLoad = false;
1112511126
if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
1112611127
Mask = VPLoad->getMask();
1112711128
PassThru = DAG.getUNDEF(VT);
@@ -11130,6 +11131,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1113011131
const auto *MLoad = cast<MaskedLoadSDNode>(Op);
1113111132
Mask = MLoad->getMask();
1113211133
PassThru = MLoad->getPassThru();
11134+
IsExpandingLoad = MLoad->isExpandingLoad();
1113311135
}
1113411136

1113511137
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
@@ -11149,25 +11151,59 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1114911151
if (!VL)
1115011152
VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
1115111153

11152-
unsigned IntID =
11153-
IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
11154+
SDValue ExpandingVL;
11155+
if (!IsUnmasked && IsExpandingLoad) {
11156+
ExpandingVL = VL;
11157+
VL =
11158+
DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
11159+
getAllOnesMask(Mask.getSimpleValueType(), VL, DL, DAG), VL);
11160+
}
11161+
11162+
unsigned IntID = IsUnmasked || IsExpandingLoad ? Intrinsic::riscv_vle
11163+
: Intrinsic::riscv_vle_mask;
1115411164
SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
11155-
if (IsUnmasked)
11165+
if (IntID == Intrinsic::riscv_vle)
1115611166
Ops.push_back(DAG.getUNDEF(ContainerVT));
1115711167
else
1115811168
Ops.push_back(PassThru);
1115911169
Ops.push_back(BasePtr);
11160-
if (!IsUnmasked)
11170+
if (IntID == Intrinsic::riscv_vle_mask)
1116111171
Ops.push_back(Mask);
1116211172
Ops.push_back(VL);
11163-
if (!IsUnmasked)
11173+
if (IntID == Intrinsic::riscv_vle_mask)
1116411174
Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
1116511175

1116611176
SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
1116711177

1116811178
SDValue Result =
1116911179
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
1117011180
Chain = Result.getValue(1);
11181+
if (ExpandingVL) {
11182+
MVT IndexVT = ContainerVT;
11183+
if (ContainerVT.isFloatingPoint())
11184+
IndexVT = ContainerVT.changeVectorElementTypeToInteger();
11185+
11186+
MVT IndexEltVT = IndexVT.getVectorElementType();
11187+
bool UseVRGATHEREI16 = false;
11188+
// If index vector is an i8 vector and the element count exceeds 256, we
11189+
// should change the element type of index vector to i16 to avoid
11190+
// overflow.
11191+
if (IndexEltVT == MVT::i8 && VT.getVectorNumElements() > 256) {
11192+
// FIXME: We need to do vector splitting manually for LMUL=8 cases.
11193+
assert(getLMUL(IndexVT) != RISCVII::LMUL_8);
11194+
IndexVT = IndexVT.changeVectorElementType(MVT::i16);
11195+
UseVRGATHEREI16 = true;
11196+
}
11197+
11198+
SDValue Iota =
11199+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
11200+
DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
11201+
DAG.getUNDEF(IndexVT), Mask, ExpandingVL);
11202+
Result =
11203+
DAG.getNode(UseVRGATHEREI16 ? RISCVISD::VRGATHEREI16_VV_VL
11204+
: RISCVISD::VRGATHER_VV_VL,
11205+
DL, ContainerVT, Result, Iota, PassThru, Mask, ExpandingVL);
11206+
}
1117111207

1117211208
if (VT.isFixedLengthVector())
1117311209
Result = convertFromScalableVector(VT, Result, DAG, Subtarget);

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2298,6 +2298,23 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
22982298
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
22992299
}
23002300

2301+
bool RISCVTTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
2302+
auto *VTy = dyn_cast<VectorType>(DataTy);
2303+
if (!VTy || VTy->isScalableTy())
2304+
return false;
2305+
2306+
if (!isLegalMaskedLoadStore(DataTy, Alignment))
2307+
return false;
2308+
2309+
// FIXME: If it is an i8 vector and the element count exceeds 256, we should
2310+
// scalarize these types with LMUL >= maximum fixed-length LMUL.
2311+
if (VTy->getElementType()->isIntegerTy(8))
2312+
if (VTy->getElementCount().getFixedValue() > 256)
2313+
return VTy->getPrimitiveSizeInBits() / ST->getRealMinVLen() <
2314+
ST->getMaxLMULForFixedLengthVectors();
2315+
return true;
2316+
}
2317+
23012318
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
23022319
auto *VTy = dyn_cast<VectorType>(DataTy);
23032320
if (!VTy || VTy->isScalableTy())

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
301301
DL);
302302
}
303303

304+
bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);
305+
304306
bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
305307

306308
bool isVScaleKnownToBeAPowerOfTwo() const {

0 commit comments

Comments (0)