Skip to content

Commit b74e779

Browse files
authored
[x86] Add lowering for @llvm.experimental.vector.compress (#104904)
This is a follow-up to #92289 that adds lowering of the new `@llvm.experimental.vector.compress` intrinsic on x86 with AVX512 instructions. This intrinsic maps directly to `vpcompress`.
1 parent acf90fd commit b74e779

File tree

6 files changed

+1345
-7
lines changed

6 files changed

+1345
-7
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
963963
SDValue SplitVecOp_VP_REDUCE(SDNode *N, unsigned OpNo);
964964
SDValue SplitVecOp_UnaryOp(SDNode *N);
965965
SDValue SplitVecOp_TruncateHelper(SDNode *N);
966+
SDValue SplitVecOp_VECTOR_COMPRESS(SDNode *N, unsigned OpNo);
966967

967968
SDValue SplitVecOp_BITCAST(SDNode *N);
968969
SDValue SplitVecOp_INSERT_SUBVECTOR(SDNode *N, unsigned OpNo);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2469,16 +2469,17 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
24692469
}
24702470

24712471
SDValue Passthru = N->getOperand(2);
2472-
if (!HasCustomLowering || !Passthru.isUndef()) {
2472+
if (!HasCustomLowering) {
24732473
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
24742474
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
24752475
return;
24762476
}
24772477

24782478
// Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
2479+
SDValue Mask = N->getOperand(1);
24792480
SDValue LoMask, HiMask;
24802481
std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
2481-
std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));
2482+
std::tie(LoMask, HiMask) = SplitMask(Mask);
24822483

24832484
SDValue UndefPassthru = DAG.getUNDEF(LoVT);
24842485
Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
@@ -2502,6 +2503,10 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
25022503
MachinePointerInfo::getUnknownStack(MF));
25032504

25042505
SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
2506+
if (!Passthru.isUndef()) {
2507+
Compressed =
2508+
DAG.getNode(ISD::VSELECT, DL, VecVT, Mask, Compressed, Passthru);
2509+
}
25052510
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
25062511
}
25072512

@@ -3259,6 +3264,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
32593264
case ISD::VSELECT:
32603265
Res = SplitVecOp_VSELECT(N, OpNo);
32613266
break;
3267+
case ISD::VECTOR_COMPRESS:
3268+
Res = SplitVecOp_VECTOR_COMPRESS(N, OpNo);
3269+
break;
32623270
case ISD::STRICT_SINT_TO_FP:
32633271
case ISD::STRICT_UINT_TO_FP:
32643272
case ISD::SINT_TO_FP:
@@ -3413,6 +3421,20 @@ SDValue DAGTypeLegalizer::SplitVecOp_VSELECT(SDNode *N, unsigned OpNo) {
34133421
return DAG.getNode(ISD::CONCAT_VECTORS, DL, Src0VT, LoSelect, HiSelect);
34143422
}
34153423

3424+
SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_COMPRESS(SDNode *N, unsigned OpNo) {
3425+
// The only possibility for an illegal operand is the mask, since result type
3426+
// legalization would have handled this node already otherwise.
3427+
assert(OpNo == 1 && "Illegal operand must be mask");
3428+
3429+
// To split the mask, we need to split the result type too, so we can just
3430+
// reuse that logic here.
3431+
SDValue Lo, Hi;
3432+
SplitVecRes_VECTOR_COMPRESS(N, Lo, Hi);
3433+
3434+
EVT VecVT = N->getValueType(0);
3435+
return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VecVT, Lo, Hi);
3436+
}
3437+
34163438
SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo) {
34173439
EVT ResVT = N->getValueType(0);
34183440
SDValue Lo, Hi;

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11705,11 +11705,13 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
1170511705
// ... if it is not a splat vector, we need to get the passthru value at
1170611706
// position = popcount(mask) and re-load it from the stack before it is
1170711707
// overwritten in the loop below.
11708+
EVT PopcountVT = ScalarVT.changeTypeToInteger();
1170811709
SDValue Popcount = DAG.getNode(
1170911710
ISD::TRUNCATE, DL, MaskVT.changeVectorElementType(MVT::i1), Mask);
11710-
Popcount = DAG.getNode(ISD::ZERO_EXTEND, DL,
11711-
MaskVT.changeVectorElementType(ScalarVT), Popcount);
11712-
Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, ScalarVT, Popcount);
11711+
Popcount =
11712+
DAG.getNode(ISD::ZERO_EXTEND, DL,
11713+
MaskVT.changeVectorElementType(PopcountVT), Popcount);
11714+
Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, PopcountVT, Popcount);
1171311715
SDValue LastElmtPtr =
1171411716
getVectorElementPointer(DAG, StackPtr, VecVT, Popcount);
1171511717
LastWriteVal = DAG.getLoad(
@@ -11748,8 +11750,10 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
1174811750

1174911751
// Re-write the last ValI if all lanes were selected. Otherwise,
1175011752
// overwrite the last write with the passthru value.
11751-
LastWriteVal =
11752-
DAG.getSelect(DL, ScalarVT, AllLanesSelected, ValI, LastWriteVal);
11753+
SDNodeFlags Flags{};
11754+
Flags.setUnpredictable(true);
11755+
LastWriteVal = DAG.getSelect(DL, ScalarVT, AllLanesSelected, ValI,
11756+
LastWriteVal, Flags);
1175311757
Chain = DAG.getStore(
1175411758
Chain, DL, LastWriteVal, OutPtr,
1175511759
MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,6 +2134,35 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
21342134
for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 })
21352135
setOperationAction(ISD::CTPOP, VT, Legal);
21362136
}
2137+
2138+
// We can try to convert vectors to different sizes to leverage legal
2139+
// `vpcompress` cases. So we mark these supported vector sizes as Custom and
2140+
// then specialize to Legal below.
2141+
for (MVT VT : {MVT::v8i32, MVT::v8f32, MVT::v4i32, MVT::v4f32, MVT::v4i64,
2142+
MVT::v4f64, MVT::v2i64, MVT::v2f64, MVT::v16i8, MVT::v8i16,
2143+
MVT::v16i16, MVT::v8i8})
2144+
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
2145+
2146+
// Legal vpcompress depends on various AVX512 extensions.
2147+
// Legal in AVX512F
2148+
for (MVT VT : {MVT::v16i32, MVT::v16f32, MVT::v8i64, MVT::v8f64})
2149+
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);
2150+
2151+
// Legal in AVX512F + AVX512VL
2152+
if (Subtarget.hasVLX())
2153+
for (MVT VT : {MVT::v8i32, MVT::v8f32, MVT::v4i32, MVT::v4f32, MVT::v4i64,
2154+
MVT::v4f64, MVT::v2i64, MVT::v2f64})
2155+
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);
2156+
2157+
// Legal in AVX512F + AVX512VBMI2
2158+
if (Subtarget.hasVBMI2())
2159+
for (MVT VT : {MVT::v32i16, MVT::v64i8})
2160+
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);
2161+
2162+
// Legal in AVX512F + AVX512VL + AVX512VBMI2
2163+
if (Subtarget.hasVBMI2() && Subtarget.hasVLX())
2164+
for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v32i8, MVT::v16i16})
2165+
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);
21372166
}
21382167

21392168
// This block controls legalization of v32i1/v64i1 which are available with
@@ -17795,6 +17824,68 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, const X86Subtarget &Subtarget,
1779517824
llvm_unreachable("Unimplemented!");
1779617825
}
1779717826

17827+
// As legal vpcompress instructions depend on various AVX512 extensions, try to
17828+
// convert illegal vector sizes to legal ones to avoid expansion.
17829+
static SDValue lowerVECTOR_COMPRESS(SDValue Op, const X86Subtarget &Subtarget,
17830+
SelectionDAG &DAG) {
17831+
assert(Subtarget.hasAVX512() &&
17832+
"Need AVX512 for custom VECTOR_COMPRESS lowering.");
17833+
17834+
SDLoc DL(Op);
17835+
SDValue Vec = Op.getOperand(0);
17836+
SDValue Mask = Op.getOperand(1);
17837+
SDValue Passthru = Op.getOperand(2);
17838+
17839+
EVT VecVT = Vec.getValueType();
17840+
EVT ElementVT = VecVT.getVectorElementType();
17841+
unsigned NumElements = VecVT.getVectorNumElements();
17842+
unsigned NumVecBits = VecVT.getFixedSizeInBits();
17843+
unsigned NumElementBits = ElementVT.getFixedSizeInBits();
17844+
17845+
// 128- and 256-bit vectors with <= 16 elements can be converted to and
17846+
// compressed as 512-bit vectors in AVX512F.
17847+
if (NumVecBits != 128 && NumVecBits != 256)
17848+
return SDValue();
17849+
17850+
if (NumElementBits == 32 || NumElementBits == 64) {
17851+
unsigned NumLargeElements = 512 / NumElementBits;
17852+
MVT LargeVecVT =
17853+
MVT::getVectorVT(ElementVT.getSimpleVT(), NumLargeElements);
17854+
MVT LargeMaskVT = MVT::getVectorVT(MVT::i1, NumLargeElements);
17855+
17856+
Vec = widenSubVector(LargeVecVT, Vec, /*ZeroNewElements=*/false, Subtarget,
17857+
DAG, DL);
17858+
Mask = widenSubVector(LargeMaskVT, Mask, /*ZeroNewElements=*/true,
17859+
Subtarget, DAG, DL);
17860+
Passthru = Passthru.isUndef() ? DAG.getUNDEF(LargeVecVT)
17861+
: widenSubVector(LargeVecVT, Passthru,
17862+
/*ZeroNewElements=*/false,
17863+
Subtarget, DAG, DL);
17864+
17865+
SDValue Compressed =
17866+
DAG.getNode(ISD::VECTOR_COMPRESS, DL, LargeVecVT, Vec, Mask, Passthru);
17867+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, Compressed,
17868+
DAG.getConstant(0, DL, MVT::i64));
17869+
}
17870+
17871+
if (VecVT == MVT::v8i16 || VecVT == MVT::v8i8 || VecVT == MVT::v16i8 ||
17872+
VecVT == MVT::v16i16) {
17873+
MVT LageElementVT = MVT::getIntegerVT(512 / NumElements);
17874+
EVT LargeVecVT = MVT::getVectorVT(LageElementVT, NumElements);
17875+
17876+
Vec = DAG.getNode(ISD::ANY_EXTEND, DL, LargeVecVT, Vec);
17877+
Passthru = Passthru.isUndef()
17878+
? DAG.getUNDEF(LargeVecVT)
17879+
: DAG.getNode(ISD::ANY_EXTEND, DL, LargeVecVT, Passthru);
17880+
17881+
SDValue Compressed =
17882+
DAG.getNode(ISD::VECTOR_COMPRESS, DL, LargeVecVT, Vec, Mask, Passthru);
17883+
return DAG.getNode(ISD::TRUNCATE, DL, VecVT, Compressed);
17884+
}
17885+
17886+
return SDValue();
17887+
}
17888+
1779817889
/// Try to lower a VSELECT instruction to a vector shuffle.
1779917890
static SDValue lowerVSELECTtoVectorShuffle(SDValue Op,
1780017891
const X86Subtarget &Subtarget,
@@ -32621,6 +32712,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3262132712
case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG);
3262232713
case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, Subtarget, DAG);
3262332714
case ISD::VECTOR_SHUFFLE: return lowerVECTOR_SHUFFLE(Op, Subtarget, DAG);
32715+
case ISD::VECTOR_COMPRESS: return lowerVECTOR_COMPRESS(Op, Subtarget, DAG);
3262432716
case ISD::VSELECT: return LowerVSELECT(Op, DAG);
3262532717
case ISD::EXTRACT_VECTOR_ELT: return LowerEXTRACT_VECTOR_ELT(Op, DAG);
3262632718
case ISD::INSERT_VECTOR_ELT: return LowerINSERT_VECTOR_ELT(Op, DAG);

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10546,6 +10546,12 @@ multiclass compress_by_vec_width_lowering<X86VectorVTInfo _, string Name> {
1054610546
def : Pat<(X86compress (_.VT _.RC:$src), _.ImmAllZerosV, _.KRCWM:$mask),
1054710547
(!cast<Instruction>(Name#_.ZSuffix#rrkz)
1054810548
_.KRCWM:$mask, _.RC:$src)>;
10549+
def : Pat<(_.VT (vector_compress _.RC:$src, _.KRCWM:$mask, undef)),
10550+
(!cast<Instruction>(Name#_.ZSuffix#rrkz)
10551+
_.KRCWM:$mask, _.RC:$src)>;
10552+
def : Pat<(_.VT (vector_compress _.RC:$src, _.KRCWM:$mask, _.RC:$passthru)),
10553+
(!cast<Instruction>(Name#_.ZSuffix#rrk)
10554+
_.RC:$passthru, _.KRCWM:$mask, _.RC:$src)>;
1054910555
}
1055010556

1055110557
multiclass compress_by_elt_width<bits<8> opc, string OpcodeStr,

0 commit comments

Comments
 (0)