@@ -6456,6 +6456,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
6456
6456
// a predicated block since it will become a fall-through, although we
6457
6457
// may decide in the future to call TTI for all branches.
6458
6458
}
6459
+ case Instruction::Switch: {
6460
+ if (VF.isScalar ())
6461
+ return TTI.getCFInstrCost (Instruction::Switch, CostKind);
6462
+ auto *Switch = cast<SwitchInst>(I);
6463
+ return Switch->getNumCases () *
6464
+ TTI.getCmpSelInstrCost (
6465
+ Instruction::ICmp,
6466
+ ToVectorTy (Switch->getCondition ()->getType (), VF),
6467
+ ToVectorTy (Type::getInt1Ty (I->getContext ()), VF),
6468
+ CmpInst::BAD_ICMP_PREDICATE, CostKind);
6469
+ }
6459
6470
case Instruction::PHI: {
6460
6471
auto *Phi = cast<PHINode>(I);
6461
6472
@@ -7843,38 +7854,58 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
7843
7854
VPValue *SrcMask = getBlockInMask (Src);
7844
7855
7845
7856
if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator ())) {
7846
- // Create mask where the terminator in Src is a switch. We need to handle 2
7847
- // separate cases:
7848
- // 1. Dst is not the default desintation. Dst is reached if any of the cases
7857
+ assert (!OrigLoop->isLoopExiting (Src) &&
7858
+ all_of (successors (Src),
7859
+ [this ](BasicBlock *Succ) {
7860
+ return OrigLoop->getHeader () != Succ;
7861
+ }) &&
7862
+ " unsupported switch either exiting loop or continuing to header" );
7863
+ // Create masks where the terminator in Src is a switch. We create mask for
7864
+ // all edges at the same time. This is more efficient, as we can create and
7865
+ // collect compares for all cases once.
7866
+ VPValue *Cond = getVPValueOrAddLiveIn (SI->getCondition (), Plan);
7867
+ BasicBlock *DefaultDst = SI->getDefaultDest ();
7868
+ MapVector<BasicBlock *, SmallVector<VPValue *>> Map;
7869
+ for (auto &C : SI->cases ()) {
7870
+ auto I = Map.insert ({C.getCaseSuccessor (), {}});
7871
+ VPValue *V = getVPValueOrAddLiveIn (C.getCaseValue (), Plan);
7872
+ I.first ->second .push_back (Builder.createICmp (CmpInst::ICMP_EQ, Cond, V));
7873
+ }
7874
+
7875
+ // We need to handle 2 separate cases:
7876
+ // 1. Dst is not the default destination. Dst is reached if any of the cases
7849
7877
// with destination == Dst are taken. Join the conditions for each case
7850
7878
// where destination == Dst using a logical OR.
7879
+ for (const auto &[Dst, Conds] : Map) {
7880
+ VPValue *Mask = Conds[0 ];
7881
+ for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front ())
7882
+ Mask = Builder.createOr (Mask, V);
7883
+ if (SrcMask)
7884
+ Mask = Builder.createLogicalAnd (SrcMask, Mask);
7885
+ EdgeMaskCache[{Src, Dst}] = Mask;
7886
+ }
7887
+
7851
7888
// 2. Dst is the default destination. Dst is reached if none of the cases
7852
7889
// with destination != Dst are taken. Join the conditions for each case
7853
7890
// where the destination is != Dst using a logical OR and negate it.
7854
- VPValue *Mask = nullptr ;
7855
- VPValue *Cond = getVPValueOrAddLiveIn (SI->getCondition (), Plan);
7856
- bool IsDefault = SI->getDefaultDest () == Dst;
7857
- for (auto &C : SI->cases ()) {
7858
- if (IsDefault) {
7859
- if (C.getCaseSuccessor () == Dst)
7860
- continue ;
7861
- } else if (C.getCaseSuccessor () != Dst)
7891
+ VPValue *DefaultMask = nullptr ;
7892
+ for (const auto &[Dst, Conds] : Map) {
7893
+ if (Dst == DefaultDst)
7894
+ continue ;
7895
+ if (!DefaultMask) {
7896
+ DefaultMask = EdgeMaskCache[{Src, Dst}];
7862
7897
continue ;
7863
-
7864
- VPValue *Eq = EdgeMaskCache.lookup ({Src, C.getCaseSuccessor ()});
7865
- if (!Eq) {
7866
- VPValue *V = getVPValueOrAddLiveIn (C.getCaseValue (), Plan);
7867
- Eq = Builder.createICmp (CmpInst::ICMP_EQ, Cond, V);
7868
7898
}
7869
- if (Mask)
7870
- Mask = Builder.createOr (Mask, Eq);
7871
- else
7872
- Mask = Eq;
7899
+ DefaultMask = Builder.createOr (DefaultMask, EdgeMaskCache[{Src, Dst}]);
7900
+ }
7901
+ if (DefaultMask) {
7902
+ DefaultMask = Builder.createNot (DefaultMask);
7903
+ if (SrcMask)
7904
+ DefaultMask = Builder.createLogicalAnd (SrcMask, DefaultMask);
7873
7905
}
7874
- if (IsDefault)
7875
- Mask = Builder.createNot (Mask);
7876
- assert (Mask && " mask must be created" );
7877
- return EdgeMaskCache[Edge] = Mask;
7906
+ EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
7907
+ assert (EdgeMaskCache.contains (Edge) && " Mask for Edge not created?" );
7908
+ return EdgeMaskCache[Edge];
7878
7909
}
7879
7910
7880
7911
// The terminator has to be a branch inst!
0 commit comments