@@ -846,8 +846,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
846
846
// know to perform better than using the popcnt instructions on each vector
847
847
// element. If popcnt isn't supported, always provide the custom version.
848
848
if (!Subtarget->hasPOPCNT()) {
849
- setOperationAction(ISD::CTPOP, MVT::v4i32, Custom);
850
849
setOperationAction(ISD::CTPOP, MVT::v2i64, Custom);
850
+ setOperationAction(ISD::CTPOP, MVT::v4i32, Custom);
851
+ setOperationAction(ISD::CTPOP, MVT::v8i16, Custom);
852
+ setOperationAction(ISD::CTPOP, MVT::v16i8, Custom);
851
853
}
852
854
853
855
// Custom lower build_vector, vector_shuffle, and extract_vector_elt.
@@ -17327,141 +17329,131 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget *Subtarget,
17327
17329
return SDValue();
17328
17330
}
17329
17331
17330
- static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
17331
- SelectionDAG &DAG) {
17332
- SDNode *Node = Op.getNode();
17333
- SDLoc dl(Node);
17334
-
17335
- Op = Op.getOperand(0);
17336
- EVT VT = Op.getValueType();
17332
+ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
17333
+ const X86Subtarget *Subtarget,
17334
+ SelectionDAG &DAG) {
17335
+ MVT VT = Op.getSimpleValueType();
17337
17336
assert((VT.is128BitVector() || VT.is256BitVector()) &&
17338
17337
"CTPOP lowering only implemented for 128/256-bit wide vector types");
17339
17338
17340
- unsigned NumElts = VT.getVectorNumElements();
17341
- EVT EltVT = VT.getVectorElementType();
17342
- unsigned Len = EltVT.getSizeInBits();
17339
+ int VecSize = VT.getSizeInBits();
17340
+ int NumElts = VT.getVectorNumElements();
17341
+ MVT EltVT = VT.getVectorElementType();
17342
+ int Len = EltVT.getSizeInBits();
17343
17343
17344
17344
// This is the vectorized version of the "best" algorithm from
17345
17345
// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
17346
17346
// with a minor tweak to use a series of adds + shifts instead of vector
17347
- // multiplications. Implemented for the v2i64, v4i64, v4i32, v8i32 types:
17348
- //
17349
- // v2i64, v4i64, v4i32 => Only profitable w/ popcnt disabled
17350
- // v8i32 => Always profitable
17347
+ // multiplications. Implemented for all integer vector types.
17351
17348
//
17352
- // FIXME: There a couple of possible improvements:
17353
- //
17354
- // 1) Support for i8 and i16 vectors (needs measurements if popcnt enabled).
17355
- // 2) Use strategies from http://wm.ite.pl/articles/sse-popcount.html
17356
- //
17357
- assert(EltVT.isInteger() && (Len == 32 || Len == 64) && Len % 8 == 0 &&
17358
- "CTPOP not implemented for this vector element type.");
17359
-
17360
- // X86 canonicalize ANDs to vXi64, generate the appropriate bitcasts to avoid
17361
- // extra legalization.
17362
- bool NeedsBitcast = EltVT == MVT::i32;
17363
- MVT BitcastVT = VT.is256BitVector() ? MVT::v4i64 : MVT::v2i64;
17349
+ // FIXME: Use strategies from http://wm.ite.pl/articles/sse-popcount.html
17364
17350
17365
- SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), dl ,
17351
+ SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), DL ,
17366
17352
EltVT);
17367
- SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), dl ,
17353
+ SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), DL ,
17368
17354
EltVT);
17369
- SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), dl ,
17355
+ SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), DL ,
17370
17356
EltVT);
17371
17357
17358
+ SDValue V = Op;
17359
+
17372
17360
// v = v - ((v >> 1) & 0x55555555...)
17373
- SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, dl, EltVT));
17374
- SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Ones);
17375
- SDValue Srl = DAG.getNode(ISD::SRL, dl, VT, Op, OnesV);
17376
- if (NeedsBitcast)
17377
- Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
17361
+ SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, DL, EltVT));
17362
+ SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ones);
17363
+ SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, V, OnesV);
17378
17364
17379
17365
SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
17380
- SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask55);
17381
- if (NeedsBitcast)
17382
- M55 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M55);
17366
+ SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask55);
17367
+ SDValue And = DAG.getNode(ISD::AND, DL, Srl.getValueType(), Srl, M55);
17383
17368
17384
- SDValue And = DAG.getNode(ISD::AND, dl, Srl.getValueType(), Srl, M55);
17385
- if (VT != And.getValueType())
17386
- And = DAG.getNode(ISD::BITCAST, dl, VT, And);
17387
- SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, Op, And);
17369
+ V = DAG.getNode(ISD::SUB, DL, VT, V, And);
17388
17370
17389
17371
// v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
17390
17372
SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
17391
- SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask33);
17392
- SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, dl, EltVT));
17393
- SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Twos);
17373
+ SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask33);
17374
+ SDValue AndLHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), V, M33);
17394
17375
17395
- Srl = DAG.getNode(ISD::SRL, dl, VT, Sub, TwosV);
17396
- if (NeedsBitcast) {
17397
- Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
17398
- M33 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M33);
17399
- Sub = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Sub);
17400
- }
17376
+ SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, DL, EltVT));
17377
+ SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Twos);
17378
+ Srl = DAG.getNode(ISD::SRL, DL, VT, V, TwosV);
17379
+ SDValue AndRHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), Srl, M33);
17401
17380
17402
- SDValue AndRHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Srl, M33);
17403
- SDValue AndLHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Sub, M33);
17404
- if (VT != AndRHS.getValueType()) {
17405
- AndRHS = DAG.getNode(ISD::BITCAST, dl, VT, AndRHS);
17406
- AndLHS = DAG.getNode(ISD::BITCAST, dl, VT, AndLHS);
17407
- }
17408
- SDValue Add = DAG.getNode(ISD::ADD, dl, VT, AndLHS, AndRHS);
17381
+ V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
17409
17382
17410
17383
// v = (v + (v >> 4)) & 0x0F0F0F0F...
17411
- SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, dl , EltVT));
17412
- SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, dl , VT, Fours);
17413
- Srl = DAG.getNode(ISD::SRL, dl , VT, Add , FoursV);
17414
- Add = DAG.getNode(ISD::ADD, dl , VT, Add , Srl);
17384
+ SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, DL , EltVT));
17385
+ SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, DL , VT, Fours);
17386
+ Srl = DAG.getNode(ISD::SRL, DL , VT, V , FoursV);
17387
+ SDValue Add = DAG.getNode(ISD::ADD, DL , VT, V , Srl);
17415
17388
17416
17389
SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
17417
- SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask0F);
17418
- if (NeedsBitcast) {
17419
- Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
17420
- M0F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M0F);
17421
- }
17422
- And = DAG.getNode(ISD::AND, dl, M0F.getValueType(), Add, M0F);
17423
- if (VT != And.getValueType())
17424
- And = DAG.getNode(ISD::BITCAST, dl, VT, And);
17390
+ SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask0F);
17425
17391
17426
- // The algorithm mentioned above uses:
17427
- // v = (v * 0x01010101...) >> (Len - 8)
17428
- //
17429
- // Change it to use vector adds + vector shifts which yield faster results on
17430
- // Haswell than using vector integer multiplication.
17431
- //
17432
- // For i32 elements:
17433
- // v = v + (v >> 8)
17434
- // v = v + (v >> 16)
17435
- //
17436
- // For i64 elements:
17437
- // v = v + (v >> 8)
17438
- // v = v + (v >> 16)
17439
- // v = v + (v >> 32)
17392
+ V = DAG.getNode(ISD::AND, DL, M0F.getValueType(), Add, M0F);
17393
+
17394
+ // At this point, V contains the byte-wise population count, and we are
17395
+ // merely doing a horizontal sum if necessary to get the wider element
17396
+ // counts.
17440
17397
//
17441
- Add = And;
17398
+ // FIXME: There is a different lowering strategy above for the horizontal sum
17399
+ // of byte-wise population counts. This one and that one should be merged,
17400
+ // using the fastest of the two for each size.
17401
+ MVT ByteVT = MVT::getVectorVT(MVT::i8, VecSize / 8);
17402
+ MVT ShiftVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
17403
+ V = DAG.getNode(ISD::BITCAST, DL, ByteVT, V);
17442
17404
SmallVector<SDValue, 8> Csts;
17443
- for (unsigned i = 8; i <= Len/2; i *= 2) {
17444
- Csts.assign(NumElts, DAG.getConstant(i, dl, EltVT));
17445
- SDValue CstsV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Csts);
17446
- Srl = DAG.getNode(ISD::SRL, dl, VT, Add, CstsV);
17447
- Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
17448
- Csts.clear();
17405
+ assert(Len <= 64 && "We don't support element sizes of more than 64 bits!");
17406
+ assert(isPowerOf2_32(Len) && "Only power of two element sizes supported!");
17407
+ for (int i = Len; i > 8; i /= 2) {
17408
+ Csts.assign(VecSize / 64, DAG.getConstant(i / 2, DL, MVT::i64));
17409
+ SDValue Shl = DAG.getNode(
17410
+ ISD::SHL, DL, ShiftVT, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V),
17411
+ DAG.getNode(ISD::BUILD_VECTOR, DL, ShiftVT, Csts));
17412
+ V = DAG.getNode(ISD::ADD, DL, ByteVT, V,
17413
+ DAG.getNode(ISD::BITCAST, DL, ByteVT, Shl));
17414
+ }
17415
+
17416
+ // The high byte now contains the sum of the element bytes. Shift it right
17417
+ // (if needed) to make it the low byte.
17418
+ V = DAG.getNode(ISD::BITCAST, DL, VT, V);
17419
+ if (Len > 8) {
17420
+ Csts.assign(NumElts, DAG.getConstant(Len - 8, DL, EltVT));
17421
+ V = DAG.getNode(ISD::SRL, DL, VT, V,
17422
+ DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Csts));
17449
17423
}
17424
+ return V;
17425
+ }
17450
17426
17451
- // The result is on the least significant 6-bits on i32 and 7-bits on i64.
17452
- SDValue Cst3F = DAG.getConstant(APInt(Len, Len == 32 ? 0x3F : 0x7F), dl,
17453
- EltVT);
17454
- SmallVector<SDValue, 8> Cst3FV(NumElts, Cst3F);
17455
- SDValue M3F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Cst3FV);
17456
- if (NeedsBitcast) {
17457
- Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
17458
- M3F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M3F);
17427
+
17428
+ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget *Subtarget,
17429
+ SelectionDAG &DAG) {
17430
+ MVT VT = Op.getSimpleValueType();
17431
+ // FIXME: Need to add AVX-512 support here!
17432
+ assert((VT.is256BitVector() || VT.is128BitVector()) &&
17433
+ "Unknown CTPOP type to handle");
17434
+ SDLoc DL(Op.getNode());
17435
+ SDValue Op0 = Op.getOperand(0);
17436
+
17437
+ if (VT.is256BitVector() && !Subtarget->hasInt256()) {
17438
+ unsigned NumElems = VT.getVectorNumElements();
17439
+
17440
+ // Extract each 128-bit vector, compute pop count and concat the result.
17441
+ SDValue LHS = Extract128BitVector(Op0, 0, DAG, DL);
17442
+ SDValue RHS = Extract128BitVector(Op0, NumElems/2, DAG, DL);
17443
+
17444
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT,
17445
+ LowerVectorCTPOPBitmath(LHS, DL, Subtarget, DAG),
17446
+ LowerVectorCTPOPBitmath(RHS, DL, Subtarget, DAG));
17459
17447
}
17460
- And = DAG.getNode(ISD::AND, dl, M3F.getValueType(), Add, M3F);
17461
- if (VT != And.getValueType())
17462
- And = DAG.getNode(ISD::BITCAST, dl, VT, And);
17463
17448
17464
- return And;
17449
+ return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG);
17450
+ }
17451
+
17452
+ static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
17453
+ SelectionDAG &DAG) {
17454
+ assert(Op.getValueType().isVector() &&
17455
+ "We only do custom lowering for vector population count.");
17456
+ return LowerVectorCTPOP(Op, Subtarget, DAG);
17465
17457
}
17466
17458
17467
17459
static SDValue LowerLOAD_SUB(SDValue Op, SelectionDAG &DAG) {
0 commit comments