@@ -471,17 +471,36 @@ static bool isValidForAlternation(unsigned Opcode) {
471
471
return true;
472
472
}
473
473
474
+ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
475
+ unsigned BaseIndex = 0); // Forward declaration: areCompatibleCmpOps() below calls getSameOpcode() before its definition; the default for BaseIndex is declared here.
476
+
477
+ /// Checks whether the operands of two cmp instructions are compatible position-wise, i.e. returns true when (BaseOp0, Op0) or (BaseOp1, Op1) are both constants,
478
+ /// when none of the four operands is an Instruction, or when getSameOpcode() yields a non-zero opcode for either pair. One matching pair is enough.
479
+ static bool areCompatibleCmpOps(Value *BaseOp0, Value *BaseOp1, Value *Op0,
480
+ Value *Op1) {
481
+ return (isConstant(BaseOp0) && isConstant(Op0)) || // first operands are both constants
482
+ (isConstant(BaseOp1) && isConstant(Op1)) || // second operands are both constants
483
+ (!isa<Instruction>(BaseOp0) && !isa<Instruction>(Op0) && // no Instruction among the four operands
484
+ !isa<Instruction>(BaseOp1) && !isa<Instruction>(Op1)) ||
485
+ getSameOpcode({BaseOp0, Op0}).getOpcode() || // non-zero opcode reported for the first-operand pair
486
+ getSameOpcode({BaseOp1, Op1}).getOpcode(); // non-zero opcode reported for the second-operand pair
487
+ }
488
+
474
489
/// \returns analysis of the Instructions in \p VL described in
475
490
/// InstructionsState, the Opcode that we suppose the whole list
476
491
/// could be vectorized even if its structure is diverse.
477
492
static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
478
- unsigned BaseIndex = 0 ) {
493
+ unsigned BaseIndex) {
479
494
// Make sure these are all Instructions.
480
495
if (llvm::any_of(VL, [](Value *V) { return !isa<Instruction>(V); }))
481
496
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
482
497
483
498
bool IsCastOp = isa<CastInst>(VL[BaseIndex]);
484
499
bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]);
500
+ bool IsCmpOp = isa<CmpInst>(VL[BaseIndex]);
501
+ CmpInst::Predicate BasePred =
502
+ IsCmpOp ? cast<CmpInst>(VL[BaseIndex])->getPredicate()
503
+ : CmpInst::BAD_ICMP_PREDICATE;
485
504
unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode();
486
505
unsigned AltOpcode = Opcode;
487
506
unsigned AltIndex = BaseIndex;
@@ -514,6 +533,57 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
514
533
continue;
515
534
}
516
535
}
536
+ } else if (IsCmpOp && isa<CmpInst>(VL[Cnt])) {
537
+ auto *BaseInst = cast<Instruction>(VL[BaseIndex]);
538
+ auto *Inst = cast<Instruction>(VL[Cnt]);
539
+ Type *Ty0 = BaseInst->getOperand(0)->getType();
540
+ Type *Ty1 = Inst->getOperand(0)->getType();
541
+ if (Ty0 == Ty1) {
542
+ Value *BaseOp0 = BaseInst->getOperand(0);
543
+ Value *BaseOp1 = BaseInst->getOperand(1);
544
+ Value *Op0 = Inst->getOperand(0);
545
+ Value *Op1 = Inst->getOperand(1);
546
+ CmpInst::Predicate CurrentPred =
547
+ cast<CmpInst>(VL[Cnt])->getPredicate();
548
+ CmpInst::Predicate SwappedCurrentPred =
549
+ CmpInst::getSwappedPredicate(CurrentPred);
550
+ // Check for compatible operands. If the corresponding operands are not
551
+ // compatible - need to perform alternate vectorization.
552
+ if (InstOpcode == Opcode) {
553
+ if (BasePred == CurrentPred &&
554
+ areCompatibleCmpOps(BaseOp0, BaseOp1, Op0, Op1))
555
+ continue;
556
+ if (BasePred == SwappedCurrentPred &&
557
+ areCompatibleCmpOps(BaseOp0, BaseOp1, Op1, Op0))
558
+ continue;
559
+ if (E == 2 &&
560
+ (BasePred == CurrentPred || BasePred == SwappedCurrentPred))
561
+ continue;
562
+ auto *AltInst = cast<CmpInst>(VL[AltIndex]);
563
+ CmpInst::Predicate AltPred = AltInst->getPredicate();
564
+ Value *AltOp0 = AltInst->getOperand(0);
565
+ Value *AltOp1 = AltInst->getOperand(1);
566
+ // Check if operands are compatible with alternate operands.
567
+ if (AltPred == CurrentPred &&
568
+ areCompatibleCmpOps(AltOp0, AltOp1, Op0, Op1))
569
+ continue;
570
+ if (AltPred == SwappedCurrentPred &&
571
+ areCompatibleCmpOps(AltOp0, AltOp1, Op1, Op0))
572
+ continue;
573
+ }
574
+ if (BaseIndex == AltIndex) {
575
+ assert(isValidForAlternation(Opcode) &&
576
+ isValidForAlternation(InstOpcode) &&
577
+ "Cast isn't safe for alternation, logic needs to be updated!");
578
+ AltIndex = Cnt;
579
+ continue;
580
+ }
581
+ auto *AltInst = cast<CmpInst>(VL[AltIndex]);
582
+ CmpInst::Predicate AltPred = AltInst->getPredicate();
583
+ if (BasePred == CurrentPred || BasePred == SwappedCurrentPred ||
584
+ AltPred == CurrentPred || AltPred == SwappedCurrentPred)
585
+ continue;
586
+ }
517
587
} else if (InstOpcode == Opcode || InstOpcode == AltOpcode)
518
588
continue;
519
589
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
@@ -4354,9 +4424,41 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
4354
4424
LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n");
4355
4425
4356
4426
// Reorder operands if reordering would enable vectorization.
4357
- if (isa<BinaryOperator>(VL0)) {
4427
+ auto *CI = dyn_cast<CmpInst>(VL0);
4428
+ if (isa<BinaryOperator>(VL0) || CI) {
4358
4429
ValueList Left, Right;
4359
- reorderInputsAccordingToOpcode (VL, Left, Right, *DL, *SE, *this );
4430
+ if (!CI || all_of(VL, [](Value *V) {
4431
+ return cast<CmpInst>(V)->isCommutative();
4432
+ })) {
4433
+ reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this);
4434
+ } else {
4435
+ CmpInst::Predicate P0 = CI->getPredicate();
4436
+ CmpInst::Predicate AltP0 = cast<CmpInst>(S.AltOp)->getPredicate();
4437
+ CmpInst::Predicate AltP0Swapped = CmpInst::getSwappedPredicate(AltP0);
4438
+ Value *BaseOp0 = VL0->getOperand(0);
4439
+ Value *BaseOp1 = VL0->getOperand(1);
4440
+ // Collect operands - commute if it uses the swapped predicate or
4441
+ // alternate operation.
4442
+ for (Value *V : VL) {
4443
+ auto *Cmp = cast<CmpInst>(V);
4444
+ Value *LHS = Cmp->getOperand(0);
4445
+ Value *RHS = Cmp->getOperand(1);
4446
+ CmpInst::Predicate CurrentPred = CI->getPredicate();
4447
+ CmpInst::Predicate CurrentPredSwapped =
4448
+ CmpInst::getSwappedPredicate(CurrentPred);
4449
+ if (P0 == AltP0 || P0 == AltP0Swapped) {
4450
+ if ((P0 == CurrentPred &&
4451
+ !areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)) ||
4452
+ (P0 == CurrentPredSwapped &&
4453
+ !areCompatibleCmpOps(BaseOp0, BaseOp1, RHS, LHS)))
4454
+ std::swap(LHS, RHS);
4455
+ } else if (!areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)) {
4456
+ std::swap(LHS, RHS);
4457
+ }
4458
+ Left.push_back(LHS);
4459
+ Right.push_back(RHS);
4460
+ }
4461
+ }
4360
4462
TE->setOperand(0, Left);
4361
4463
TE->setOperand(1, Right);
4362
4464
buildTree_rec(Left, Depth + 1, {TE, 0});
@@ -5288,7 +5390,8 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
5288
5390
((Instruction::isBinaryOp(E->getOpcode()) &&
5289
5391
Instruction::isBinaryOp(E->getAltOpcode())) ||
5290
5392
(Instruction::isCast(E->getOpcode()) &&
5291
- Instruction::isCast (E->getAltOpcode ()))) &&
5393
+ Instruction::isCast(E->getAltOpcode())) ||
5394
+ (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) &&
5292
5395
"Invalid Shuffle Vector Operand");
5293
5396
InstructionCost ScalarCost = 0;
5294
5397
if (NeedToShuffleReuses) {
@@ -5336,6 +5439,14 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
5336
5439
VecCost = TTI->getArithmeticInstrCost(E->getOpcode(), VecTy, CostKind);
5337
5440
VecCost += TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy,
5338
5441
CostKind);
5442
+ } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) {
5443
+ VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy,
5444
+ Builder.getInt1Ty(),
5445
+ CI0->getPredicate(), CostKind, VL0);
5446
+ VecCost += TTI->getCmpSelInstrCost(
5447
+ E->getOpcode(), ScalarTy, Builder.getInt1Ty(),
5448
+ cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind,
5449
+ E->getAltOp());
5339
5450
} else {
5340
5451
Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType();
5341
5452
Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType();
@@ -5352,6 +5463,29 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
5352
5463
E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices,
5353
5464
[E](Instruction *I) {
5354
5465
assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
5466
+ if (auto *CI0 = dyn_cast<CmpInst>(E->getMainOp())) {
5467
+ auto *AltCI0 = cast<CmpInst>(E->getAltOp());
5468
+ auto *CI = cast<CmpInst>(I);
5469
+ CmpInst::Predicate P0 = CI0->getPredicate();
5470
+ CmpInst::Predicate AltP0 = AltCI0->getPredicate();
5471
+ CmpInst::Predicate AltP0Swapped =
5472
+ CmpInst::getSwappedPredicate(AltP0);
5473
+ CmpInst::Predicate CurrentPred = CI->getPredicate();
5474
+ CmpInst::Predicate CurrentPredSwapped =
5475
+ CmpInst::getSwappedPredicate(CurrentPred);
5476
+ if (P0 == AltP0 || P0 == AltP0Swapped) {
5477
+ // Alternate cmps have same/swapped predicate as main cmps but
5478
+ // different order of compatible operands.
5479
+ return !(
5480
+ (P0 == CurrentPred &&
5481
+ areCompatibleCmpOps(CI0->getOperand(0), CI0->getOperand(1),
5482
+ I->getOperand(0), I->getOperand(1))) ||
5483
+ (P0 == CurrentPredSwapped &&
5484
+ areCompatibleCmpOps(CI0->getOperand(0), CI0->getOperand(1),
5485
+ I->getOperand(1), I->getOperand(0))));
5486
+ }
5487
+ return CurrentPred != P0 && CurrentPredSwapped != P0;
5488
+ }
5355
5489
return I->getOpcode() == E->getAltOpcode();
5356
5490
},
5357
5491
Mask);
@@ -6834,11 +6968,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
6834
6968
((Instruction::isBinaryOp(E->getOpcode()) &&
6835
6969
Instruction::isBinaryOp(E->getAltOpcode())) ||
6836
6970
(Instruction::isCast(E->getOpcode()) &&
6837
- Instruction::isCast (E->getAltOpcode ()))) &&
6971
+ Instruction::isCast(E->getAltOpcode())) ||
6972
+ (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) &&
6838
6973
"Invalid Shuffle Vector Operand");
6839
6974
6840
6975
Value *LHS = nullptr, *RHS = nullptr;
6841
- if (Instruction::isBinaryOp (E->getOpcode ())) {
6976
+ if (Instruction::isBinaryOp(E->getOpcode()) || isa<CmpInst>(VL0) ) {
6842
6977
setInsertPointAfterBundle(E);
6843
6978
LHS = vectorizeTree(E->getOperand(0));
6844
6979
RHS = vectorizeTree(E->getOperand(1));
@@ -6858,6 +6993,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
6858
6993
static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, RHS);
6859
6994
V1 = Builder.CreateBinOp(
6860
6995
static_cast<Instruction::BinaryOps>(E->getAltOpcode()), LHS, RHS);
6996
+ } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) {
6997
+ V0 = Builder.CreateCmp(CI0->getPredicate(), LHS, RHS);
6998
+ auto *AltCI = cast<CmpInst>(E->getAltOp());
6999
+ CmpInst::Predicate AltPred = AltCI->getPredicate();
7000
+ unsigned AltIdx =
7001
+ std::distance(E->Scalars.begin(), find(E->Scalars, AltCI));
7002
+ if (AltCI->getOperand(0) != E->getOperand(0)[AltIdx])
7003
+ AltPred = CmpInst::getSwappedPredicate(AltPred);
7004
+ V1 = Builder.CreateCmp(AltPred, LHS, RHS);
6861
7005
} else {
6862
7006
V0 = Builder.CreateCast(
6863
7007
static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy);
@@ -6882,6 +7026,29 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
6882
7026
E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices,
6883
7027
[E](Instruction *I) {
6884
7028
assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
7029
+ if (auto *CI0 = dyn_cast<CmpInst>(E->getMainOp())) {
7030
+ auto *AltCI0 = cast<CmpInst>(E->getAltOp());
7031
+ auto *CI = cast<CmpInst>(I);
7032
+ CmpInst::Predicate P0 = CI0->getPredicate();
7033
+ CmpInst::Predicate AltP0 = AltCI0->getPredicate();
7034
+ CmpInst::Predicate AltP0Swapped =
7035
+ CmpInst::getSwappedPredicate(AltP0);
7036
+ CmpInst::Predicate CurrentPred = CI->getPredicate();
7037
+ CmpInst::Predicate CurrentPredSwapped =
7038
+ CmpInst::getSwappedPredicate(CurrentPred);
7039
+ if (P0 == AltP0 || P0 == AltP0Swapped) {
7040
+ // Alternate cmps have same/swapped predicate as main cmps but
7041
+ // different order of compatible operands.
7042
+ return !(
7043
+ (P0 == CurrentPred &&
7044
+ areCompatibleCmpOps(CI0->getOperand(0), CI0->getOperand(1),
7045
+ I->getOperand(0), I->getOperand(1))) ||
7046
+ (P0 == CurrentPredSwapped &&
7047
+ areCompatibleCmpOps(CI0->getOperand(0), CI0->getOperand(1),
7048
+ I->getOperand(1), I->getOperand(0))));
7049
+ }
7050
+ return CurrentPred != P0 && CurrentPredSwapped != P0;
7051
+ }
6885
7052
return I->getOpcode() == E->getAltOpcode();
6886
7053
},
6887
7054
Mask, &OpScalars, &AltScalars);
0 commit comments