Skip to content

Commit f0df4fb

Browse files
authored
[LV] Support generating masks for switch terminators. (#99808)
Update createEdgeMask to create masks where the terminator in Src is a switch. We need to handle 2 separate cases: 1. Dst is not the default destination. Dst is reached if any of the cases with destination == Dst are taken. Join the conditions for each case where destination == Dst using a logical OR. 2. Dst is the default destination. Dst is reached if none of the cases with destination != Dst are taken. Join the conditions for each case where the destination is != Dst using a logical OR and negate it. Edge masks are created for every destination of cases and/or default when requesting a mask where the source is a switch. Fixes #48188. PR: #99808
1 parent d1bc41f commit f0df4fb

File tree

8 files changed

+1349
-52
lines changed

8 files changed

+1349
-52
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,12 +1340,21 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
13401340

13411341
// Collect the blocks that need predication.
13421342
for (BasicBlock *BB : TheLoop->blocks()) {
1343-
// We don't support switch statements inside loops.
1344-
if (!isa<BranchInst>(BB->getTerminator())) {
1345-
reportVectorizationFailure("Loop contains a switch statement",
1346-
"loop contains a switch statement",
1347-
"LoopContainsSwitch", ORE, TheLoop,
1348-
BB->getTerminator());
1343+
// We support only branches and switch statements as terminators inside the
1344+
// loop.
1345+
if (isa<SwitchInst>(BB->getTerminator())) {
1346+
if (TheLoop->isLoopExiting(BB)) {
1347+
reportVectorizationFailure("Loop contains an unsupported switch",
1348+
"loop contains an unsupported switch",
1349+
"LoopContainsUnsupportedSwitch", ORE,
1350+
TheLoop, BB->getTerminator());
1351+
return false;
1352+
}
1353+
} else if (!isa<BranchInst>(BB->getTerminator())) {
1354+
reportVectorizationFailure("Loop contains an unsupported terminator",
1355+
"loop contains an unsupported terminator",
1356+
"LoopContainsUnsupportedTerminator", ORE,
1357+
TheLoop, BB->getTerminator());
13491358
return false;
13501359
}
13511360

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6453,6 +6453,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
64536453
// a predicated block since it will become a fall-through, although we
64546454
// may decide in the future to call TTI for all branches.
64556455
}
6456+
case Instruction::Switch: {
6457+
if (VF.isScalar())
6458+
return TTI.getCFInstrCost(Instruction::Switch, CostKind);
6459+
auto *Switch = cast<SwitchInst>(I);
6460+
return Switch->getNumCases() *
6461+
TTI.getCmpSelInstrCost(
6462+
Instruction::ICmp,
6463+
ToVectorTy(Switch->getCondition()->getType(), VF),
6464+
ToVectorTy(Type::getInt1Ty(I->getContext()), VF),
6465+
CmpInst::ICMP_EQ, CostKind);
6466+
}
64566467
case Instruction::PHI: {
64576468
auto *Phi = cast<PHINode>(I);
64586469

@@ -7841,6 +7852,62 @@ VPRecipeBuilder::mapToVPValues(User::op_range Operands) {
78417852
return map_range(Operands, Fn);
78427853
}
78437854

7855+
/// Create and cache the edge masks for every outgoing edge of the switch
/// \p SI: one mask per distinct non-default case destination and one mask for
/// the default destination. All masks are stored in EdgeMaskCache, keyed by
/// {SI->getParent(), Destination}.
///
/// The switch must neither exit the loop nor branch to the loop header; this
/// is asserted below.
void VPRecipeBuilder::createSwitchEdgeMasks(SwitchInst *SI) {
  BasicBlock *Src = SI->getParent();
  assert(!OrigLoop->isLoopExiting(Src) &&
         all_of(successors(Src),
                [this](BasicBlock *Succ) {
                  return OrigLoop->getHeader() != Succ;
                }) &&
         "unsupported switch either exiting loop or continuing to header");
  // Create masks where the terminator in Src is a switch. We create masks for
  // all edges at the same time. This is more efficient, as we can create and
  // collect compares for all cases once.
  VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
  BasicBlock *DefaultDst = SI->getDefaultDest();
  // Maps each non-default destination to the compares (Cond == CaseValue) of
  // all cases branching to it, in case order.
  MapVector<BasicBlock *, SmallVector<VPValue *>> Dst2Compares;
  for (auto &C : SI->cases()) {
    BasicBlock *Dst = C.getCaseSuccessor();
    assert(!EdgeMaskCache.contains({Src, Dst}) && "Edge masks already created");
    // Cases whose destination is the same as default are redundant and can be
    // ignored - they will get there anyhow.
    if (Dst == DefaultDst)
      continue;
    auto I = Dst2Compares.insert({Dst, {}});
    VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
    I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
  }

  // We need to handle 2 separate cases below for all entries in Dst2Compares,
  // which excludes destinations matching the default destination.
  VPValue *SrcMask = getBlockInMask(Src);
  // OR of the edge masks of all non-default destinations; stays null if there
  // are none.
  VPValue *DefaultMask = nullptr;
  for (const auto &[Dst, Conds] : Dst2Compares) {
    // 1. Dst is not the default destination. Dst is reached if any of the cases
    // with destination == Dst are taken. Join the conditions for each case
    // whose destination == Dst using an OR.
    VPValue *Mask = Conds[0];
    for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
      Mask = Builder.createOr(Mask, V);
    if (SrcMask)
      Mask = Builder.createLogicalAnd(SrcMask, Mask);
    EdgeMaskCache[{Src, Dst}] = Mask;

    // Accumulate the masks of all non-default edges; the default mask is
    // derived from their union after the loop.
    DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask;
  }

  // 2. Create the mask for the default destination, which is reached if none
  // of the cases with destination != default destination are taken. Negate
  // the OR of all non-default edge masks and combine it with the mask of Src.
  if (DefaultMask) {
    DefaultMask = Builder.createNot(DefaultMask);
    if (SrcMask)
      DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
  } else {
    // There are no destinations other than the default destination, so the
    // switch behaves like an unconditional branch: the edge is taken whenever
    // Src is executed. Use the mask of Src instead of dropping it.
    DefaultMask = SrcMask;
  }
  EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
}
7910+
78447911
VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
78457912
assert(is_contained(predecessors(Dst), Src) && "Invalid edge");
78467913

@@ -7850,12 +7917,17 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
78507917
if (ECEntryIt != EdgeMaskCache.end())
78517918
return ECEntryIt->second;
78527919

7920+
if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator())) {
7921+
createSwitchEdgeMasks(SI);
7922+
assert(EdgeMaskCache.contains(Edge) && "Mask for Edge not created?");
7923+
return EdgeMaskCache[Edge];
7924+
}
7925+
78537926
VPValue *SrcMask = getBlockInMask(Src);
78547927

78557928
// The terminator has to be a branch inst!
78567929
BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator());
78577930
assert(BI && "Unexpected terminator found");
7858-
78597931
if (!BI->isConditional() || BI->getSuccessor(0) == BI->getSuccessor(1))
78607932
return EdgeMaskCache[Edge] = SrcMask;
78617933

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ class VPRecipeBuilder {
134134
/// Returns the *entry* mask for the block \p BB.
135135
VPValue *getBlockInMask(BasicBlock *BB) const;
136136

137+
/// Create an edge mask for every destination of cases and/or default.
138+
void createSwitchEdgeMasks(SwitchInst *SI);
139+
137140
/// A helper function that computes the predicate of the edge between SRC
138141
/// and DST.
139142
VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst);

0 commit comments

Comments
 (0)