Skip to content

Commit cb4461c

Browse files
committed
!fixup address comments, thanks!
1 parent 150b51f commit cb4461c

File tree

5 files changed

+242
-554
lines changed

5 files changed

+242
-554
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,9 +1339,18 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
13391339

13401340
// Collect the blocks that need predication.
13411341
for (BasicBlock *BB : TheLoop->blocks()) {
1342-
// We don't support switch statements inside loops.
1343-
if (!isa<BranchInst, SwitchInst>(BB->getTerminator())) {
1344-
reportVectorizationFailure("Loop contains an unsupported termaintor",
1342+
// We support only branches and switch statements as terminators inside the
1343+
// loop.
1344+
if (isa<SwitchInst>(BB->getTerminator())) {
1345+
if (TheLoop->isLoopExiting(BB)) {
1346+
reportVectorizationFailure("Loop contains an unsupported switch",
1347+
"loop contains an unsupported switch",
1348+
"LoopContainsUnsupportedSwitch", ORE,
1349+
TheLoop, BB->getTerminator());
1350+
return false;
1351+
}
1352+
} else if (!isa<BranchInst>(BB->getTerminator())) {
1353+
reportVectorizationFailure("Loop contains an unsupported terminator",
13451354
"loop contains an unsupported terminator",
13461355
"LoopContainsUnsupportedTerminator", ORE,
13471356
TheLoop, BB->getTerminator());

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6456,6 +6456,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
64566456
// a predicated block since it will become a fall-through, although we
64576457
// may decide in the future to call TTI for all branches.
64586458
}
6459+
case Instruction::Switch: {
6460+
if (VF.isScalar())
6461+
return TTI.getCFInstrCost(Instruction::Switch, CostKind);
6462+
auto *Switch = cast<SwitchInst>(I);
6463+
return Switch->getNumCases() *
6464+
TTI.getCmpSelInstrCost(
6465+
Instruction::ICmp,
6466+
ToVectorTy(Switch->getCondition()->getType(), VF),
6467+
ToVectorTy(Type::getInt1Ty(I->getContext()), VF),
6468+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
6469+
}
64596470
case Instruction::PHI: {
64606471
auto *Phi = cast<PHINode>(I);
64616472

@@ -7843,38 +7854,58 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
78437854
VPValue *SrcMask = getBlockInMask(Src);
78447855

78457856
if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator())) {
7846-
// Create mask where the terminator in Src is a switch. We need to handle 2
7847-
// separate cases:
7848-
// 1. Dst is not the default desintation. Dst is reached if any of the cases
7857+
assert(!OrigLoop->isLoopExiting(Src) &&
7858+
all_of(successors(Src),
7859+
[this](BasicBlock *Succ) {
7860+
return OrigLoop->getHeader() != Succ;
7861+
}) &&
7862+
"unsupported switch either exiting loop or continuing to header");
7863+
// Create masks where the terminator in Src is a switch. We create masks for
7864+
// all edges at the same time. This is more efficient, as we can create and
7865+
// collect compares for all cases once.
7866+
VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
7867+
BasicBlock *DefaultDst = SI->getDefaultDest();
7868+
MapVector<BasicBlock *, SmallVector<VPValue *>> Map;
7869+
for (auto &C : SI->cases()) {
7870+
auto I = Map.insert({C.getCaseSuccessor(), {}});
7871+
VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
7872+
I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
7873+
}
7874+
7875+
// We need to handle 2 separate cases:
7876+
// 1. Dst is not the default destination. Dst is reached if any of the cases
78497877
// with destination == Dst are taken. Join the conditions for each case
78507878
// where destination == Dst using a logical OR.
7879+
for (const auto &[Dst, Conds] : Map) {
7880+
VPValue *Mask = Conds[0];
7881+
for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
7882+
Mask = Builder.createOr(Mask, V);
7883+
if (SrcMask)
7884+
Mask = Builder.createLogicalAnd(SrcMask, Mask);
7885+
EdgeMaskCache[{Src, Dst}] = Mask;
7886+
}
7887+
78517888
// 2. Dst is the default destination. Dst is reached if none of the cases
78527889
// with destination != Dst are taken. Join the conditions for each case
78537890
// where the destination is != Dst using a logical OR and negate it.
7854-
VPValue *Mask = nullptr;
7855-
VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
7856-
bool IsDefault = SI->getDefaultDest() == Dst;
7857-
for (auto &C : SI->cases()) {
7858-
if (IsDefault) {
7859-
if (C.getCaseSuccessor() == Dst)
7860-
continue;
7861-
} else if (C.getCaseSuccessor() != Dst)
7891+
VPValue *DefaultMask = nullptr;
7892+
for (const auto &[Dst, Conds] : Map) {
7893+
if (Dst == DefaultDst)
7894+
continue;
7895+
if (!DefaultMask) {
7896+
DefaultMask = EdgeMaskCache[{Src, Dst}];
78627897
continue;
7863-
7864-
VPValue *Eq = EdgeMaskCache.lookup({Src, C.getCaseSuccessor()});
7865-
if (!Eq) {
7866-
VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
7867-
Eq = Builder.createICmp(CmpInst::ICMP_EQ, Cond, V);
78687898
}
7869-
if (Mask)
7870-
Mask = Builder.createOr(Mask, Eq);
7871-
else
7872-
Mask = Eq;
7899+
DefaultMask = Builder.createOr(DefaultMask, EdgeMaskCache[{Src, Dst}]);
7900+
}
7901+
if (DefaultMask) {
7902+
DefaultMask = Builder.createNot(DefaultMask);
7903+
if (SrcMask)
7904+
DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
78737905
}
7874-
if (IsDefault)
7875-
Mask = Builder.createNot(Mask);
7876-
assert(Mask && "mask must be created");
7877-
return EdgeMaskCache[Edge] = Mask;
7906+
EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
7907+
assert(EdgeMaskCache.contains(Edge) && "Mask for Edge not created?");
7908+
return EdgeMaskCache[Edge];
78787909
}
78797910

78807911
// The terminator has to be a branch inst!

0 commit comments

Comments
 (0)