Skip to content

Commit aa68e28

Browse files
[RISCV] Support llvm.masked.compressstore intrinsic (#83457)
The changeset enables lowering of `llvm.masked.compressstore(%data, %ptr, %mask)` for RVV for fixed vector type into: ``` %0 = vcompress %data, %mask, %vl %new_vl = vcpop %mask, %vl vse %0, %ptr, %1, %new_vl ``` Such lowering is only possible when `%data` fits into available LMULs and otherwise `llvm.masked.compressstore` is scalarized by `ScalarizeMaskedMemIntrin` pass. Even though RVV spec in the section `15.8` provide alternative sequence for compressstore, use of `vcompress + vcpop` should be a proper canonical form to lower `llvm.masked.compressstore`. If RISC-V target find the sequence from `15.8` better, peephole optimization can transform `vcompress + vcpop` into that sequence.
1 parent b77c079 commit aa68e28

6 files changed

+1085
-1746
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10466,6 +10466,7 @@ SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
1046610466
SDValue BasePtr = MemSD->getBasePtr();
1046710467
SDValue Val, Mask, VL;
1046810468

10469+
bool IsCompressingStore = false;
1046910470
if (const auto *VPStore = dyn_cast<VPStoreSDNode>(Op)) {
1047010471
Val = VPStore->getValue();
1047110472
Mask = VPStore->getMask();
@@ -10474,9 +10475,11 @@ SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
1047410475
const auto *MStore = cast<MaskedStoreSDNode>(Op);
1047510476
Val = MStore->getValue();
1047610477
Mask = MStore->getMask();
10478+
IsCompressingStore = MStore->isCompressingStore();
1047710479
}
1047810480

10479-
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
10481+
bool IsUnmasked =
10482+
ISD::isConstantSplatVectorAllOnes(Mask.getNode()) || IsCompressingStore;
1048010483

1048110484
MVT VT = Val.getSimpleValueType();
1048210485
MVT XLenVT = Subtarget.getXLenVT();
@@ -10486,7 +10489,7 @@ SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
1048610489
ContainerVT = getContainerForFixedLengthVector(VT);
1048710490

1048810491
Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
10489-
if (!IsUnmasked) {
10492+
if (!IsUnmasked || IsCompressingStore) {
1049010493
MVT MaskVT = getMaskTypeFor(ContainerVT);
1049110494
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
1049210495
}
@@ -10495,6 +10498,15 @@ SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
1049510498
if (!VL)
1049610499
VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
1049710500

10501+
if (IsCompressingStore) {
10502+
Val = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
10503+
DAG.getConstant(Intrinsic::riscv_vcompress, DL, XLenVT),
10504+
DAG.getUNDEF(ContainerVT), Val, Mask, VL);
10505+
VL =
10506+
DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
10507+
getAllOnesMask(Mask.getSimpleValueType(), VL, DL, DAG), VL);
10508+
}
10509+
1049810510
unsigned IntID =
1049910511
IsUnmasked ? Intrinsic::riscv_vse : Intrinsic::riscv_vse_mask;
1050010512
SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,3 +1620,13 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
16201620
C2.NumIVMuls, C2.NumBaseAdds,
16211621
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
16221622
}
1623+
1624+
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
1625+
auto *VTy = dyn_cast<VectorType>(DataTy);
1626+
if (!VTy || VTy->isScalableTy())
1627+
return false;
1628+
1629+
if (!isLegalMaskedLoadStore(DataTy, Alignment))
1630+
return false;
1631+
return true;
1632+
}

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
261261
return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
262262
}
263263

264+
bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
265+
264266
bool isVScaleKnownToBeAPowerOfTwo() const {
265267
return TLI->isVScaleKnownToBeAPowerOfTwo();
266268
}

0 commit comments

Comments
 (0)