
Commit 828f5b8
[x86] Implement a faster vector population count based on the PSHUFB in-register LUT technique.

Summary:
A description of this technique can be found here:
http://wm.ite.pl/articles/sse-popcount.html

The core of the idea is to use an in-register lookup table and the PSHUFB instruction to compute the population count for the low and high nibbles of each byte, and then to use horizontal sums to aggregate these into vector population counts with wider element types.

On x86 there is an instruction that will directly compute the horizontal sum for the low 8 and high 8 bytes, giving vNi64 popcount very easily. Various tricks are used to get vNi32 and vNi16 from the vNi8 that the LUT computes.

The base implementation of this, and most of the work, was done by Bruno in a follow-up to D6531. See Bruno's detailed post there for lots of timing information about these changes.

I have extended Bruno's patch in the following ways:

0) I committed the new tests with baseline sequences so this shows a diff, and regenerated the tests using the update scripts.

1) Bruno had noticed and mentioned in IRC a redundant mask that I removed.

2) I introduced a particular optimization for the i32 vector cases where we use PSHL + PSADBW to compute the low i32 popcounts, and PSHUFD + PSADBW to compute doubled high i32 popcounts. This takes advantage of the fact that to line up the high i32 popcounts we have to shift them anyways, and we can shift them by one fewer bit to effectively divide the count by two. While the PSHUFD-based horizontal add is no faster, it doesn't require registers or load traffic the way a mask would, and provides more ILP as it happens on different ports with high throughput.

3) I did some code cleanups throughout to simplify the implementation logic.

4) I refactored it to continue to use the parallel bitmath lowering when SSSE3 is not available, to preserve the performance of that version on SSE2 targets where it is still much better than scalarizing, as we'll still do a bitmath implementation of popcount even in scalar code there.

With #1 and #2 above, I analyzed the result in IACA for Sandy Bridge, Ivy Bridge, and Haswell. In every case I measured, the throughput is the same or better using the LUT lowering, even for v2i64 and v4i64, and even compared with using the native popcnt instruction! The latency of the LUT lowering is often higher than the latency of the scalarized popcnt instruction sequence, but I think those latency measurements are deeply misleading. Keeping the operation fully in the vector unit and having many chances for increased throughput seems much more likely to win.

With this, we can lower every integer vector popcount implementation using the LUT strategy if we have SSSE3 or better (and thus have PSHUFB). I've updated the operation lowering to reflect this. This also fixes an issue where we were horribly scalarizing some AVX lowerings.

Finally, there are some remaining cleanups. There is duplication between the two techniques in how they perform the horizontal sum once the byte population count is computed. I'm going to factor and merge those two in a separate follow-up commit.

Differential Revision: http://reviews.llvm.org/D10084

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@238636 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 43d1e87 commit 828f5b8
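
For readers who want the technique outside of SelectionDAG, here is a minimal sketch of the same LUT + PSADBW idea written with SSE intrinsics. It illustrates what the new lowering produces for v2i64; it is not code from the patch, and the helper name popcnt_v2i64 is hypothetical.

#include <emmintrin.h> // SSE2: PSADBW (_mm_sad_epu8), shifts, masks
#include <tmmintrin.h> // SSSE3: PSHUFB (_mm_shuffle_epi8)

// Population count of both 64-bit lanes of v, via the in-register LUT.
static __m128i popcnt_v2i64(__m128i v) {
  // 16-entry popcount table for nibble values 0x0..0xF, held in a register.
  const __m128i LUT = _mm_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3,
                                    1, 2, 2, 3, 2, 3, 3, 4);
  const __m128i M0F = _mm_set1_epi8(0x0F);
  // Low nibbles index the LUT directly; high nibbles are shifted down first.
  // There is no per-byte shift on x86, so shift i16 lanes and mask the bits
  // that leak in from the neighboring byte.
  __m128i Lo = _mm_and_si128(v, M0F);
  __m128i Hi = _mm_and_si128(_mm_srli_epi16(v, 4), M0F);
  // Two PSHUFB lookups plus a byte add yield the popcount of every byte.
  __m128i Cnt = _mm_add_epi8(_mm_shuffle_epi8(LUT, Lo),
                             _mm_shuffle_epi8(LUT, Hi));
  // PSADBW against zero horizontally sums each 8-byte half into an i64 lane,
  // which is exactly the v2i64 population count.
  return _mm_sad_epu8(Cnt, _mm_setzero_si128());
}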

File tree: 6 files changed, +698 −2054 lines changed


lib/Target/X86/X86ISelLowering.cpp (+176 −28)
@@ -842,15 +842,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i32, Custom);
     setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4f32, Custom);

-    // Only provide customized ctpop vector bit twiddling for vector types we
-    // know to perform better than using the popcnt instructions on each vector
-    // element. If popcnt isn't supported, always provide the custom version.
-    if (!Subtarget->hasPOPCNT()) {
-      setOperationAction(ISD::CTPOP, MVT::v2i64, Custom);
-      setOperationAction(ISD::CTPOP, MVT::v4i32, Custom);
-      setOperationAction(ISD::CTPOP, MVT::v8i16, Custom);
-      setOperationAction(ISD::CTPOP, MVT::v16i8, Custom);
-    }
+    setOperationAction(ISD::CTPOP, MVT::v16i8, Custom);
+    setOperationAction(ISD::CTPOP, MVT::v8i16, Custom);
+    setOperationAction(ISD::CTPOP, MVT::v4i32, Custom);
+    setOperationAction(ISD::CTPOP, MVT::v2i64, Custom);

     // Custom lower build_vector, vector_shuffle, and extract_vector_elt.
     for (int i = MVT::v16i8; i != MVT::v2i64; ++i) {
@@ -1115,6 +1110,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::TRUNCATE, MVT::v8i16, Custom);
     setOperationAction(ISD::TRUNCATE, MVT::v4i32, Custom);

+    setOperationAction(ISD::CTPOP, MVT::v32i8, Custom);
+    setOperationAction(ISD::CTPOP, MVT::v16i16, Custom);
+    setOperationAction(ISD::CTPOP, MVT::v8i32, Custom);
+    setOperationAction(ISD::CTPOP, MVT::v4i64, Custom);
+
     if (Subtarget->hasFMA() || Subtarget->hasFMA4()) {
       setOperationAction(ISD::FMA, MVT::v8f32, Legal);
       setOperationAction(ISD::FMA, MVT::v4f64, Legal);
@@ -1149,16 +1149,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     // when we have a 256bit-wide blend with immediate.
     setOperationAction(ISD::UINT_TO_FP, MVT::v8i32, Custom);

-    // Only provide customized ctpop vector bit twiddling for vector types we
-    // know to perform better than using the popcnt instructions on each
-    // vector element. If popcnt isn't supported, always provide the custom
-    // version.
-    if (!Subtarget->hasPOPCNT())
-      setOperationAction(ISD::CTPOP, MVT::v4i64, Custom);
-
-    // Custom CTPOP always performs better on natively supported v8i32
-    setOperationAction(ISD::CTPOP, MVT::v8i32, Custom);
-
     // AVX2 also has wider vector sign/zero extending loads, VPMOV[SZ]X
     setLoadExtAction(ISD::SEXTLOAD, MVT::v16i16, MVT::v16i8, Legal);
     setLoadExtAction(ISD::SEXTLOAD, MVT::v8i32, MVT::v8i8, Legal);
@@ -17329,12 +17319,164 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget *Subtarget,
   return SDValue();
 }

+static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
+                                        const X86Subtarget *Subtarget,
+                                        SelectionDAG &DAG) {
+  EVT VT = Op.getValueType();
+  MVT EltVT = VT.getVectorElementType().getSimpleVT();
+  unsigned VecSize = VT.getSizeInBits();
+
+  // Implement a lookup table in register by using an algorithm based on:
+  // http://wm.ite.pl/articles/sse-popcount.html
+  //
+  // The general idea is that every lower byte nibble in the input vector is an
+  // index into an in-register pre-computed pop count table. We then split up
+  // the input vector in two new ones: (1) a vector with only the shifted-right
+  // higher nibbles for each byte and (2) a vector with the lower nibbles (and
+  // masked-out higher ones) for each byte. PSHUFB is used separately with both
+  // to index the in-register table. Next, both are added and the result is an
+  // i8 vector where each element contains the pop count for its input byte.
+  //
+  // To obtain the pop count for elements != i8, we follow up with the same
+  // approach and use additional tricks as described below.
+  //
+  const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
+                       /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
+                       /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
+                       /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4};
+
+  int NumByteElts = VecSize / 8;
+  MVT ByteVecVT = MVT::getVectorVT(MVT::i8, NumByteElts);
+  SDValue In = DAG.getNode(ISD::BITCAST, DL, ByteVecVT, Op);
+  SmallVector<SDValue, 16> LUTVec;
+  for (int i = 0; i < NumByteElts; ++i)
+    LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8));
+  SDValue InRegLUT = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, LUTVec);
+  SmallVector<SDValue, 16> Mask0F(NumByteElts,
+                                  DAG.getConstant(0x0F, DL, MVT::i8));
+  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Mask0F);
+
+  // High nibbles
+  SmallVector<SDValue, 16> Four(NumByteElts, DAG.getConstant(4, DL, MVT::i8));
+  SDValue FourV = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Four);
+  SDValue HighNibbles = DAG.getNode(ISD::SRL, DL, ByteVecVT, In, FourV);
+
+  // Low nibbles
+  SDValue LowNibbles = DAG.getNode(ISD::AND, DL, ByteVecVT, In, M0F);
+
+  // The input vector is used as the shuffle mask that indexes elements into
+  // the LUT. After counting low and high nibbles, add the vectors to obtain
+  // the final pop count per i8 element.
+  SDValue HighPopCnt =
+      DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, HighNibbles);
+  SDValue LowPopCnt =
+      DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, LowNibbles);
+  SDValue PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, HighPopCnt, LowPopCnt);
+
+  if (EltVT == MVT::i8)
+    return PopCnt;
+
+  // The PSADBW instruction horizontally adds all bytes and leaves the result
+  // in i64 chunks, thus directly computing the pop count for v2i64 and v4i64.
+  if (EltVT == MVT::i64) {
+    SDValue Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL);
+    PopCnt = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT, PopCnt, Zeros);
+    return DAG.getNode(ISD::BITCAST, DL, VT, PopCnt);
+  }
+
+  int NumI64Elts = VecSize / 64;
+  MVT VecI64VT = MVT::getVectorVT(MVT::i64, NumI64Elts);
+
+  if (EltVT == MVT::i32) {
+    // We unpack the low half and high half into i32s interleaved with zeros so
+    // that we can use PSADBW to horizontally sum them. The most useful part of
+    // this is that it lines up the results of two PSADBW instructions to be
+    // two v2i64 vectors which concatenated are the 4 population counts. We can
+    // then use PACKUSWB to shrink and concatenate them into a v4i32 again.
+    SDValue Zeros = getZeroVector(VT, Subtarget, DAG, DL);
+    SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, PopCnt, Zeros);
+    SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, PopCnt, Zeros);
+
+    // Do the horizontal sums into two v2i64s.
+    Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL);
+    Low = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT,
+                      DAG.getNode(ISD::BITCAST, DL, ByteVecVT, Low), Zeros);
+    High = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT,
+                       DAG.getNode(ISD::BITCAST, DL, ByteVecVT, High), Zeros);
+
+    // Merge them together.
+    MVT ShortVecVT = MVT::getVectorVT(MVT::i16, VecSize / 16);
+    PopCnt = DAG.getNode(X86ISD::PACKUS, DL, ByteVecVT,
+                         DAG.getNode(ISD::BITCAST, DL, ShortVecVT, Low),
+                         DAG.getNode(ISD::BITCAST, DL, ShortVecVT, High));
+
+    return DAG.getNode(ISD::BITCAST, DL, VT, PopCnt);
+  }
+
+  // To obtain the pop count for each i16 element, shuffle the byte pop counts
+  // to get even and odd elements into distinct vectors, add them and
+  // zero-extend each i8 element into i16, i.e.:
+  //
+  // B -> pop count per i8
+  // W -> pop count per i16
+  //
+  // Y = shuffle B, undef <0, 2, ...>
+  // Z = shuffle B, undef <1, 3, ...>
+  // W = zext <... x i8> to <... x i16> (Y + Z)
+  //
+  // Use a byte shuffle mask that matches PSHUFB.
+  //
+  assert(EltVT == MVT::i16 && "Unknown how to handle type");
+  SDValue Undef = DAG.getUNDEF(ByteVecVT);
+  SmallVector<int, 32> MaskA, MaskB;
+
+  // We can't use PSHUFB across lanes, so do the shuffle and sum inside each
+  // 128-bit lane, and then collapse the result.
+  int NumLanes = NumByteElts / 16;
+  assert(NumByteElts % 16 == 0 && "Must have 16-byte multiple vectors!");
+  for (int i = 0; i < NumLanes; ++i) {
+    for (int j = 0; j < 8; ++j) {
+      MaskA.push_back(i * 16 + j * 2);
+      MaskB.push_back(i * 16 + (j * 2) + 1);
+    }
+    MaskA.append((size_t)8, -1);
+    MaskB.append((size_t)8, -1);
+  }
+
+  SDValue ShuffA = DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, Undef, MaskA);
+  SDValue ShuffB = DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, Undef, MaskB);
+  PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, ShuffA, ShuffB);
+
+  SmallVector<int, 4> Mask;
+  for (int i = 0; i < NumLanes; ++i)
+    Mask.push_back(2 * i);
+  Mask.append((size_t)NumLanes, -1);
+
+  PopCnt = DAG.getNode(ISD::BITCAST, DL, VecI64VT, PopCnt);
+  PopCnt =
+      DAG.getVectorShuffle(VecI64VT, DL, PopCnt, DAG.getUNDEF(VecI64VT), Mask);
+  PopCnt = DAG.getNode(ISD::BITCAST, DL, ByteVecVT, PopCnt);
+
+  // Zero-extend i8s into i16 elts
+  SmallVector<int, 16> ZExtInRegMask;
+  for (int i = 0; i < NumByteElts / 2; ++i) {
+    ZExtInRegMask.push_back(i);
+    ZExtInRegMask.push_back(NumByteElts);
+  }
+
+  return DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      DAG.getVectorShuffle(ByteVecVT, DL, PopCnt,
+                           getZeroVector(ByteVecVT, Subtarget, DAG, DL),
+                           ZExtInRegMask));
+}
+
 static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
                                        const X86Subtarget *Subtarget,
                                        SelectionDAG &DAG) {
   MVT VT = Op.getSimpleValueType();
-  assert((VT.is128BitVector() || VT.is256BitVector()) &&
-         "CTPOP lowering only implemented for 128/256-bit wide vector types");
+  assert(VT.is128BitVector() &&
+         "Only 128-bit vector bitmath lowering supported.");

   int VecSize = VT.getSizeInBits();
   int NumElts = VT.getVectorNumElements();
@@ -17344,9 +17486,9 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
   // This is the vectorized version of the "best" algorithm from
   // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
   // with a minor tweak to use a series of adds + shifts instead of vector
-  // multiplications. Implemented for all integer vector types.
-  //
-  // FIXME: Use strategies from http://wm.ite.pl/articles/sse-popcount.html
+  // multiplications. Implemented for all integer vector types. We only use
+  // this when we don't have SSSE3 which allows a LUT-based lowering that is
+  // much faster, even faster than using native popcnt instructions.

   SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), DL,
                                   EltVT);
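
For comparison, the bitmath lowering that this comment now scopes to pre-SSSE3 targets corresponds, for a single v4i32, to roughly the following intrinsics sketch. This is a hand-written illustration rather than the patch's DAG code, and the helper name is hypothetical.

#include <emmintrin.h> // SSE2 only: no PSHUFB required

// Parallel-bitmath popcount of each i32 lane, with the commit's adds+shifts
// replacing the scalar version's final multiply.
static __m128i popcnt_bitmath_v4i32(__m128i v) {
  const __m128i M55 = _mm_set1_epi32(0x55555555);
  const __m128i M33 = _mm_set1_epi32(0x33333333);
  const __m128i M0F = _mm_set1_epi32(0x0F0F0F0F);
  // Per-2-bit, then per-4-bit, then per-byte partial counts.
  v = _mm_sub_epi32(v, _mm_and_si128(_mm_srli_epi32(v, 1), M55));
  v = _mm_add_epi32(_mm_and_si128(v, M33),
                    _mm_and_si128(_mm_srli_epi32(v, 2), M33));
  v = _mm_and_si128(_mm_add_epi32(v, _mm_srli_epi32(v, 4)), M0F);
  // Sum the four byte counts of each lane with shifts + adds.
  v = _mm_add_epi32(v, _mm_srli_epi32(v, 8));
  v = _mm_add_epi32(v, _mm_srli_epi32(v, 16));
  return _mm_and_si128(v, _mm_set1_epi32(0x3F));
}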
@@ -17424,7 +17566,6 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
   return V;
 }

-
 static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget *Subtarget,
                                 SelectionDAG &DAG) {
   MVT VT = Op.getSimpleValueType();
@@ -17434,6 +17575,12 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget *Subtarget,
   SDLoc DL(Op.getNode());
   SDValue Op0 = Op.getOperand(0);

+  if (!Subtarget->hasSSSE3()) {
+    // We can't use the fast LUT approach, so fall back on vectorized bitmath.
+    assert(VT.is128BitVector() && "Only 128-bit vectors supported in SSE!");
+    return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG);
+  }
+
   if (VT.is256BitVector() && !Subtarget->hasInt256()) {
     unsigned NumElems = VT.getVectorNumElements();

@@ -17442,11 +17589,11 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget *Subtarget,
     SDValue RHS = Extract128BitVector(Op0, NumElems/2, DAG, DL);

     return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT,
-                       LowerVectorCTPOPBitmath(LHS, DL, Subtarget, DAG),
-                       LowerVectorCTPOPBitmath(RHS, DL, Subtarget, DAG));
+                       LowerVectorCTPOPInRegLUT(LHS, DL, Subtarget, DAG),
+                       LowerVectorCTPOPInRegLUT(RHS, DL, Subtarget, DAG));
   }

-  return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG);
+  return LowerVectorCTPOPInRegLUT(Op0, DL, Subtarget, DAG);
 }

 static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
@@ -18149,6 +18296,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   case X86ISD::VPERMI: return "X86ISD::VPERMI";
   case X86ISD::PMULUDQ: return "X86ISD::PMULUDQ";
   case X86ISD::PMULDQ: return "X86ISD::PMULDQ";
+  case X86ISD::PSADBW: return "X86ISD::PSADBW";
   case X86ISD::VASTART_SAVE_XMM_REGS: return "X86ISD::VASTART_SAVE_XMM_REGS";
   case X86ISD::VAARG_64: return "X86ISD::VAARG_64";
   case X86ISD::WIN_ALLOCA: return "X86ISD::WIN_ALLOCA";
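
The v4i32 branch of LowerVectorCTPOPInRegLUT above is easier to see with intrinsics. A rough sketch follows, assuming bytecnt already holds the per-byte counts from the PSHUFB step; the helper name is hypothetical.

#include <emmintrin.h> // SSE2: UNPCK, PSADBW, PACKUSWB

// Sketch of the i32 path: from per-byte counts to per-i32 counts.
static __m128i popcnt_v4i32_from_bytes(__m128i bytecnt) {
  const __m128i Zero = _mm_setzero_si128();
  // Interleave the four i32 chunks of byte counts with zeros so that each
  // group of four counts sits alone in an 8-byte half.
  __m128i Low = _mm_unpacklo_epi32(bytecnt, Zero);
  __m128i High = _mm_unpackhi_epi32(bytecnt, Zero);
  // PSADBW sums each 8-byte half: Low = [cnt0, cnt1], High = [cnt2, cnt3],
  // both as v2i64.
  Low = _mm_sad_epu8(Low, Zero);
  High = _mm_sad_epu8(High, Zero);
  // PACKUSWB narrows the i16 view of both vectors; every count is at most 32,
  // so it survives the unsigned saturation, and the result read as v4i32 is
  // [cnt0, cnt1, cnt2, cnt3].
  return _mm_packus_epi16(Low, High);
}

The i16 path of the function plays a similar game with PSHUFB shuffles instead: it separates even and odd byte counts, adds them, and zero-extends the sums back into i16 lanes.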

lib/Target/X86/X86ISelLowering.h (+3)
@@ -184,6 +184,9 @@ namespace llvm {
       /// Shuffle 16 8-bit values within a vector.
       PSHUFB,

+      /// Compute Sum of Absolute Differences.
+      PSADBW,
+
       /// Bitwise Logical AND NOT of Packed FP values.
       ANDNP,

lib/Target/X86/X86InstrFragmentsSIMD.td (+3)
@@ -78,6 +78,9 @@ def X86cmps : SDNode<"X86ISD::FSETCC", SDTX86Cmps>;
 def X86pshufb : SDNode<"X86ISD::PSHUFB",
                  SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
                                       SDTCisSameAs<0,2>]>>;
+def X86psadbw : SDNode<"X86ISD::PSADBW",
+                 SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
+                                      SDTCisSameAs<0,2>]>>;
 def X86andnp : SDNode<"X86ISD::ANDNP",
                  SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
                                       SDTCisSameAs<0,2>]>>;

lib/Target/X86/X86InstrSSE.td (+14)
@@ -4053,6 +4053,20 @@ defm PAVGW : PDI_binop_all_int<0xE3, "pavgw", int_x86_sse2_pavg_w,
 defm PSADBW : PDI_binop_all_int<0xF6, "psadbw", int_x86_sse2_psad_bw,
                                 int_x86_avx2_psad_bw, SSE_PMADD, 1>;

+let Predicates = [HasAVX2] in
+  def : Pat<(v32i8 (X86psadbw (v32i8 VR256:$src1),
+                              (v32i8 VR256:$src2))),
+            (VPSADBWYrr VR256:$src2, VR256:$src1)>;
+
+let Predicates = [HasAVX] in
+  def : Pat<(v16i8 (X86psadbw (v16i8 VR128:$src1),
+                              (v16i8 VR128:$src2))),
+            (VPSADBWrr VR128:$src2, VR128:$src1)>;
+
+def : Pat<(v16i8 (X86psadbw (v16i8 VR128:$src1),
+                            (v16i8 VR128:$src2))),
+          (PSADBWrr VR128:$src2, VR128:$src1)>;
+
 let Predicates = [HasAVX] in
 defm VPMULUDQ : PDI_binop_rm2<0xF4, "vpmuludq", X86pmuludq, v2i64, v4i32, VR128,
                               loadv2i64, i128mem, SSE_INTMUL_ITINS_P, 1, 0>,
