Skip to content

Commit b34d919

Browse files
committed
!fixup address latest comments, thanks!
1 parent 66f27dd commit b34d919

File tree

2 files changed

+61
-53
lines changed

2 files changed

+61
-53
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 57 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7850,6 +7850,60 @@ VPRecipeBuilder::mapToVPValues(User::op_range Operands) {
78507850
return map_range(Operands, Fn);
78517851
}
78527852

7853+
/// Create and cache the masks of all outgoing edges of the switch \p SI.
///
/// The masks are stored in EdgeMaskCache, keyed by {parent block of \p SI,
/// destination block}. For each destination other than the default one, the
/// mask is the OR of the compares of all cases targeting it. The mask of the
/// default destination is the negation of the OR of all those masks. Both are
/// combined with the block-in mask of the parent block when that mask is
/// non-null.
void VPRecipeBuilder::createSwitchEdgeMasks(SwitchInst *SI) {
  BasicBlock *Src = SI->getParent();
  assert(!OrigLoop->isLoopExiting(Src) &&
         all_of(successors(Src),
                [this](BasicBlock *Succ) {
                  return OrigLoop->getHeader() != Succ;
                }) &&
         "unsupported switch either exiting loop or continuing to header");
  // Create masks where the terminator in Src is a switch. We create masks for
  // all edges at the same time. This is more efficient, as we can create and
  // collect compares for all cases once.
  VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
  BasicBlock *DefaultDst = SI->getDefaultDest();
  MapVector<BasicBlock *, SmallVector<VPValue *>> Dst2Compares;
  for (auto &C : SI->cases()) {
    // Cases whose destination is the same as default are redundant and can be
    // ignored - they will get there anyhow.
    if (C.getCaseSuccessor() == DefaultDst)
      continue;
    auto I = Dst2Compares.insert({C.getCaseSuccessor(), {}});
    VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
    I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
  }

  // We need to handle 2 separate cases below for all entries in Dst2Compares,
  // which excludes destinations matching the default destination.
  VPValue *SrcMask = getBlockInMask(Src);
  VPValue *DefaultMask = nullptr;
  for (const auto &[Dst, Conds] : Dst2Compares) {
    // 1. Dst is not the default destination. Dst is reached if any of the cases
    // with destination == Dst are taken. Join the conditions for each case
    // whose destination == Dst using an OR.
    VPValue *Mask = Conds[0];
    for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
      Mask = Builder.createOr(Mask, V);
    if (SrcMask)
      Mask = Builder.createLogicalAnd(SrcMask, Mask);
    EdgeMaskCache[{Src, Dst}] = Mask;

    // 2. Accumulate the mask for the default destination, which is reached if
    // none of the cases with destination != default destination are taken.
    // Join the masks of all non-default destinations using an OR here; the
    // result is negated once after the loop.
    DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask;
  }

  if (DefaultMask) {
    DefaultMask = Builder.createNot(DefaultMask);
    if (SrcMask)
      DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
  } else {
    // No case targets a block other than the default destination, so the
    // switch acts like an unconditional branch to DefaultDst. Use the mask of
    // Src for the edge, matching what createEdgeMask caches for unconditional
    // branches, instead of leaving the edge mask null.
    DefaultMask = SrcMask;
  }
  EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
}
7906+
78537907
VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
78547908
assert(is_contained(predecessors(Dst), Src) && "Invalid edge");
78557909

@@ -7859,67 +7913,17 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
78597913
if (ECEntryIt != EdgeMaskCache.end())
78607914
return ECEntryIt->second;
78617915

7862-
VPValue *SrcMask = getBlockInMask(Src);
7863-
78647916
if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator())) {
7865-
assert(!OrigLoop->isLoopExiting(Src) &&
7866-
all_of(successors(Src),
7867-
[this](BasicBlock *Succ) {
7868-
return OrigLoop->getHeader() != Succ;
7869-
}) &&
7870-
"unsupported switch either exiting loop or continuing to header");
7871-
// Create masks where the terminator in Src is a switch. We create mask for
7872-
// all edges at the same time. This is more efficient, as we can create and
7873-
// collect compares for all cases once.
7874-
VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
7875-
BasicBlock *DefaultDst = SI->getDefaultDest();
7876-
MapVector<BasicBlock *, SmallVector<VPValue *>> Map;
7877-
for (auto &C : SI->cases()) {
7878-
auto I = Map.insert({C.getCaseSuccessor(), {}});
7879-
VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
7880-
I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
7881-
}
7882-
7883-
// We need to handle 2 separate cases:
7884-
// 1. Dst is not the default destination. Dst is reached if any of the cases
7885-
// with destination == Dst are taken. Join the conditions for each case
7886-
// where destination == Dst using a logical OR.
7887-
for (const auto &[Dst, Conds] : Map) {
7888-
VPValue *Mask = Conds[0];
7889-
for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
7890-
Mask = Builder.createOr(Mask, V);
7891-
if (SrcMask)
7892-
Mask = Builder.createLogicalAnd(SrcMask, Mask);
7893-
EdgeMaskCache[{Src, Dst}] = Mask;
7894-
}
7895-
7896-
// 2. Dst is the default destination. Dst is reached if none of the cases
7897-
// with destination != Dst are taken. Join the conditions for each case
7898-
// where the destination is != Dst using a logical OR and negate it.
7899-
VPValue *DefaultMask = nullptr;
7900-
for (const auto &[Dst, Conds] : Map) {
7901-
if (Dst == DefaultDst)
7902-
continue;
7903-
if (!DefaultMask) {
7904-
DefaultMask = EdgeMaskCache[{Src, Dst}];
7905-
continue;
7906-
}
7907-
DefaultMask = Builder.createOr(DefaultMask, EdgeMaskCache[{Src, Dst}]);
7908-
}
7909-
if (DefaultMask) {
7910-
DefaultMask = Builder.createNot(DefaultMask);
7911-
if (SrcMask)
7912-
DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
7913-
}
7914-
EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
7917+
createSwitchEdgeMasks(SI);
79157918
assert(EdgeMaskCache.contains(Edge) && "Mask for Edge not created?");
79167919
return EdgeMaskCache[Edge];
79177920
}
79187921

7922+
VPValue *SrcMask = getBlockInMask(Src);
7923+
79197924
// The terminator has to be a branch inst!
79207925
BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator());
79217926
assert(BI && "Unexpected terminator found");
7922-
79237927
if (!BI->isConditional() || BI->getSuccessor(0) == BI->getSuccessor(1))
79247928
return EdgeMaskCache[Edge] = SrcMask;
79257929

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ class VPRecipeBuilder {
134134
/// Returns the *entry* mask for the block \p BB.
135135
VPValue *getBlockInMask(BasicBlock *BB) const;
136136

137+
/// Create masks for all cases with destination different than the default
138+
/// destination, and a mask for the default destination.
139+
void createSwitchEdgeMasks(SwitchInst *SI);
140+
137141
/// A helper function that computes the predicate of the edge between SRC
138142
/// and DST.
139143
VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst);

0 commit comments

Comments
 (0)