@@ -29233,28 +29233,6 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
29233
29233
return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones);
29234
29234
}
29235
29235
29236
- static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
29237
- TargetLowering::DAGCombinerInfo &DCI,
29238
- const X86Subtarget &Subtarget) {
29239
- if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
29240
- return Cmp;
29241
-
29242
- if (DCI.isBeforeLegalizeOps())
29243
- return SDValue();
29244
-
29245
- if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
29246
- return RV;
29247
-
29248
- if (Subtarget.hasCMov())
29249
- if (SDValue RV = combineIntegerAbs(N, DAG))
29250
- return RV;
29251
-
29252
- if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
29253
- return FPLogic;
29254
-
29255
- return SDValue();
29256
- }
29257
-
29258
29236
/// This function detects the AVG pattern between vectors of unsigned i8/i16,
29259
29237
/// which is c = (a + b + 1) / 2, and replace this operation with the efficient
29260
29238
/// X86ISD::AVG instruction.
@@ -30363,12 +30341,68 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
30363
30341
return combineVectorTruncation(N, DAG, Subtarget);
30364
30342
}
30365
30343
30344
+ /// Returns the negated value if the node \p N flips sign of FP value.
30345
+ ///
30346
+ /// FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000).
30347
+ /// AVX512F does not have FXOR, so FNEG is lowered as
30348
+ /// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
30349
+ /// In this case we go though all bitcasts.
30350
+ static SDValue isFNEG(SDNode *N) {
30351
+ if (N->getOpcode() == ISD::FNEG)
30352
+ return N->getOperand(0);
30353
+
30354
+ SDValue Op = peekThroughBitcasts(SDValue(N, 0));
30355
+ if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
30356
+ return SDValue();
30357
+
30358
+ SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
30359
+ if (!Op1.getValueType().isFloatingPoint())
30360
+ return SDValue();
30361
+
30362
+ SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
30363
+
30364
+ unsigned EltBits = Op1.getValueType().getScalarSizeInBits();
30365
+ auto isSignBitValue = [&](const ConstantFP *C) {
30366
+ return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
30367
+ };
30368
+
30369
+ // There is more than one way to represent the same constant on
30370
+ // the different X86 targets. The type of the node may also depend on size.
30371
+ // - load scalar value and broadcast
30372
+ // - BUILD_VECTOR node
30373
+ // - load from a constant pool.
30374
+ // We check all variants here.
30375
+ if (Op1.getOpcode() == X86ISD::VBROADCAST) {
30376
+ if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
30377
+ if (isSignBitValue(cast<ConstantFP>(C)))
30378
+ return Op0;
30379
+
30380
+ } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
30381
+ if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
30382
+ if (isSignBitValue(CN->getConstantFPValue()))
30383
+ return Op0;
30384
+
30385
+ } else if (auto *C = getTargetConstantFromNode(Op1)) {
30386
+ if (C->getType()->isVectorTy()) {
30387
+ if (auto *SplatV = C->getSplatValue())
30388
+ if (isSignBitValue(cast<ConstantFP>(SplatV)))
30389
+ return Op0;
30390
+ } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
30391
+ if (isSignBitValue(FPConst))
30392
+ return Op0;
30393
+ }
30394
+ return SDValue();
30395
+ }
30396
+
30366
30397
/// Do target-specific dag combines on floating point negations.
30367
30398
static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
30368
30399
const X86Subtarget &Subtarget) {
30369
- EVT VT = N->getValueType(0);
30400
+ EVT OrigVT = N->getValueType(0);
30401
+ SDValue Arg = isFNEG(N);
30402
+ assert(Arg.getNode() && "N is expected to be an FNEG node");
30403
+
30404
+ EVT VT = Arg.getValueType();
30370
30405
EVT SVT = VT.getScalarType();
30371
- SDValue Arg = N->getOperand(0);
30372
30406
SDLoc DL(N);
30373
30407
30374
30408
// Let legalize expand this if it isn't a legal type yet.
@@ -30381,40 +30415,30 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
30381
30415
if (Arg.getOpcode() == ISD::FMUL && (SVT == MVT::f32 || SVT == MVT::f64) &&
30382
30416
Arg->getFlags()->hasNoSignedZeros() && Subtarget.hasAnyFMA()) {
30383
30417
SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
30384
- return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
30385
- Arg.getOperand(1), Zero);
30418
+ SDValue NewNode = DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
30419
+ Arg.getOperand(1), Zero);
30420
+ return DAG.getBitcast(OrigVT, NewNode);
30386
30421
}
30387
30422
30388
30423
// If we're negating a FMA node, then we can adjust the
30389
30424
// instruction to include the extra negation.
30425
+ unsigned NewOpcode = 0;
30390
30426
if (Arg.hasOneUse()) {
30391
30427
switch (Arg.getOpcode()) {
30392
- case X86ISD::FMADD:
30393
- return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
30394
- Arg.getOperand(1), Arg.getOperand(2));
30395
- case X86ISD::FMSUB:
30396
- return DAG.getNode(X86ISD::FNMADD, DL, VT, Arg.getOperand(0),
30397
- Arg.getOperand(1), Arg.getOperand(2));
30398
- case X86ISD::FNMADD:
30399
- return DAG.getNode(X86ISD::FMSUB, DL, VT, Arg.getOperand(0),
30400
- Arg.getOperand(1), Arg.getOperand(2));
30401
- case X86ISD::FNMSUB:
30402
- return DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0),
30403
- Arg.getOperand(1), Arg.getOperand(2));
30404
- case X86ISD::FMADD_RND:
30405
- return DAG.getNode(X86ISD::FNMSUB_RND, DL, VT, Arg.getOperand(0),
30406
- Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
30407
- case X86ISD::FMSUB_RND:
30408
- return DAG.getNode(X86ISD::FNMADD_RND, DL, VT, Arg.getOperand(0),
30409
- Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
30410
- case X86ISD::FNMADD_RND:
30411
- return DAG.getNode(X86ISD::FMSUB_RND, DL, VT, Arg.getOperand(0),
30412
- Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
30413
- case X86ISD::FNMSUB_RND:
30414
- return DAG.getNode(X86ISD::FMADD_RND, DL, VT, Arg.getOperand(0),
30415
- Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
30428
+ case X86ISD::FMADD: NewOpcode = X86ISD::FNMSUB; break;
30429
+ case X86ISD::FMSUB: NewOpcode = X86ISD::FNMADD; break;
30430
+ case X86ISD::FNMADD: NewOpcode = X86ISD::FMSUB; break;
30431
+ case X86ISD::FNMSUB: NewOpcode = X86ISD::FMADD; break;
30432
+ case X86ISD::FMADD_RND: NewOpcode = X86ISD::FNMSUB_RND; break;
30433
+ case X86ISD::FMSUB_RND: NewOpcode = X86ISD::FNMADD_RND; break;
30434
+ case X86ISD::FNMADD_RND: NewOpcode = X86ISD::FMSUB_RND; break;
30435
+ case X86ISD::FNMSUB_RND: NewOpcode = X86ISD::FMADD_RND; break;
30416
30436
}
30417
30437
}
30438
+ if (NewOpcode)
30439
+ return DAG.getBitcast(OrigVT, DAG.getNode(NewOpcode, DL, VT,
30440
+ Arg.getNode()->ops()));
30441
+
30418
30442
return SDValue();
30419
30443
}
30420
30444
@@ -30442,42 +30466,28 @@ static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG,
30442
30466
return SDValue();
30443
30467
}
30444
30468
30445
- /// Returns true if the node \p N is FNEG(x) or FXOR (x, 0x80000000).
30446
- bool isFNEG(const SDNode *N) {
30447
- if (N->getOpcode() == ISD::FNEG)
30448
- return true;
30469
+ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
30470
+ TargetLowering::DAGCombinerInfo &DCI,
30471
+ const X86Subtarget &Subtarget) {
30472
+ if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
30473
+ return Cmp;
30449
30474
30450
- if (N->getOpcode() == X86ISD::FXOR) {
30451
- unsigned EltBits = N->getSimpleValueType(0).getScalarSizeInBits();
30452
- SDValue Op1 = N->getOperand(1);
30475
+ if (DCI.isBeforeLegalizeOps())
30476
+ return SDValue();
30453
30477
30454
- auto isSignBitValue = [&](const ConstantFP *C) {
30455
- return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
30456
- };
30478
+ if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
30479
+ return RV;
30457
30480
30458
- // There is more than one way to represent the same constant on
30459
- // the different X86 targets. The type of the node may also depend on size.
30460
- // - load scalar value and broadcast
30461
- // - BUILD_VECTOR node
30462
- // - load from a constant pool.
30463
- // We check all variants here.
30464
- if (Op1.getOpcode() == X86ISD::VBROADCAST) {
30465
- if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
30466
- return isSignBitValue(cast<ConstantFP>(C));
30481
+ if (Subtarget.hasCMov())
30482
+ if (SDValue RV = combineIntegerAbs(N, DAG))
30483
+ return RV;
30467
30484
30468
- } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
30469
- if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
30470
- return isSignBitValue(CN->getConstantFPValue());
30485
+ if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
30486
+ return FPLogic;
30471
30487
30472
- } else if (auto *C = getTargetConstantFromNode(Op1)) {
30473
- if (C->getType()->isVectorTy()) {
30474
- if (auto *SplatV = C->getSplatValue())
30475
- return isSignBitValue(cast<ConstantFP>(SplatV));
30476
- } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
30477
- return isSignBitValue(FPConst);
30478
- }
30479
- }
30480
- return false;
30488
+ if (isFNEG(N))
30489
+ return combineFneg(N, DAG, Subtarget);
30490
+ return SDValue();
30481
30491
}
30482
30492
30483
30493
/// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes.
@@ -30907,18 +30917,20 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
30907
30917
SDValue B = N->getOperand(1);
30908
30918
SDValue C = N->getOperand(2);
30909
30919
30910
- bool NegA = isFNEG(A.getNode());
30911
- bool NegB = isFNEG(B.getNode());
30912
- bool NegC = isFNEG(C.getNode());
30920
+ auto invertIfNegative = [](SDValue &V) {
30921
+ if (SDValue NegVal = isFNEG(V.getNode())) {
30922
+ V = NegVal;
30923
+ return true;
30924
+ }
30925
+ return false;
30926
+ };
30927
+
30928
+ bool NegA = invertIfNegative(A);
30929
+ bool NegB = invertIfNegative(B);
30930
+ bool NegC = invertIfNegative(C);
30913
30931
30914
30932
// Negative multiplication when NegA xor NegB
30915
30933
bool NegMul = (NegA != NegB);
30916
- if (NegA)
30917
- A = A.getOperand(0);
30918
- if (NegB)
30919
- B = B.getOperand(0);
30920
- if (NegC)
30921
- C = C.getOperand(0);
30922
30934
30923
30935
unsigned NewOpcode;
30924
30936
if (!NegMul)
0 commit comments