Skip to content

Commit c0abd1b

Browse files
committed
!fixup address latest comments, thanks!
1 parent 40298c8 commit c0abd1b

File tree

2 files changed

+61
-53
lines changed

2 files changed

+61
-53
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 57 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7842,6 +7842,60 @@ VPRecipeBuilder::mapToVPValues(User::op_range Operands) {
78427842
return map_range(Operands, Fn);
78437843
}
78447844

7845+
/// Create and cache in EdgeMaskCache the masks for all outgoing edges of the
/// switch \p SI: one mask per destination that differs from the default
/// destination, plus one mask for the default destination. The switch must
/// neither exit the loop nor branch to the loop header.
void VPRecipeBuilder::createSwitchEdgeMasks(SwitchInst *SI) {
  BasicBlock *Src = SI->getParent();
  assert(!OrigLoop->isLoopExiting(Src) &&
         all_of(successors(Src),
                [this](BasicBlock *Succ) {
                  return OrigLoop->getHeader() != Succ;
                }) &&
         "unsupported switch either exiting loop or continuing to header");
  // Create masks where the terminator in Src is a switch. We create masks for
  // all edges at the same time. This is more efficient, as we can create and
  // collect compares for all cases once.
  VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
  BasicBlock *DefaultDst = SI->getDefaultDest();
  MapVector<BasicBlock *, SmallVector<VPValue *>> Dst2Compares;
  for (auto &C : SI->cases()) {
    // Cases whose destination is the same as default are redundant and can be
    // ignored - they will get there anyhow.
    if (C.getCaseSuccessor() == DefaultDst)
      continue;
    auto I = Dst2Compares.insert({C.getCaseSuccessor(), {}});
    VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
    I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
  }

  // We need to handle 2 separate cases below for all entries in Dst2Compares,
  // which excludes destinations matching the default destination.
  VPValue *SrcMask = getBlockInMask(Src);
  VPValue *DefaultMask = nullptr;
  for (const auto &[Dst, Conds] : Dst2Compares) {
    // 1. Dst is not the default destination. Dst is reached if any of the cases
    // with destination == Dst are taken. Join the conditions for each case
    // whose destination == Dst using an OR.
    VPValue *Mask = Conds[0];
    for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
      Mask = Builder.createOr(Mask, V);
    if (SrcMask)
      Mask = Builder.createLogicalAnd(SrcMask, Mask);
    EdgeMaskCache[{Src, Dst}] = Mask;

    // 2. Accumulate the OR of all non-default edge masks; it is negated after
    // the loop to form the mask for the default destination. Note that each
    // accumulated mask already includes SrcMask, if there is one.
    DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask;
  }

  if (DefaultMask) {
    // The default destination is reached if none of the cases with a
    // destination != default destination are taken.
    DefaultMask = Builder.createNot(DefaultMask);
    if (SrcMask)
      DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
  } else {
    // There are no destinations other than the default destination, so the
    // switch behaves like an unconditional branch. Use the mask of Src for the
    // edge, as createEdgeMask does for unconditional branches, instead of
    // dropping the predicate of Src.
    DefaultMask = SrcMask;
  }
  EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
}
7898+
78457899
VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
78467900
assert(is_contained(predecessors(Dst), Src) && "Invalid edge");
78477901

@@ -7851,67 +7905,17 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
78517905
if (ECEntryIt != EdgeMaskCache.end())
78527906
return ECEntryIt->second;
78537907

7854-
VPValue *SrcMask = getBlockInMask(Src);
7855-
78567908
if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator())) {
7857-
assert(!OrigLoop->isLoopExiting(Src) &&
7858-
all_of(successors(Src),
7859-
[this](BasicBlock *Succ) {
7860-
return OrigLoop->getHeader() != Succ;
7861-
}) &&
7862-
"unsupported switch either exiting loop or continuing to header");
7863-
// Create masks where the terminator in Src is a switch. We create mask for
7864-
// all edges at the same time. This is more efficient, as we can create and
7865-
// collect compares for all cases once.
7866-
VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
7867-
BasicBlock *DefaultDst = SI->getDefaultDest();
7868-
MapVector<BasicBlock *, SmallVector<VPValue *>> Map;
7869-
for (auto &C : SI->cases()) {
7870-
auto I = Map.insert({C.getCaseSuccessor(), {}});
7871-
VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
7872-
I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
7873-
}
7874-
7875-
// We need to handle 2 separate cases:
7876-
// 1. Dst is not the default destination. Dst is reached if any of the cases
7877-
// with destination == Dst are taken. Join the conditions for each case
7878-
// where destination == Dst using a logical OR.
7879-
for (const auto &[Dst, Conds] : Map) {
7880-
VPValue *Mask = Conds[0];
7881-
for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
7882-
Mask = Builder.createOr(Mask, V);
7883-
if (SrcMask)
7884-
Mask = Builder.createLogicalAnd(SrcMask, Mask);
7885-
EdgeMaskCache[{Src, Dst}] = Mask;
7886-
}
7887-
7888-
// 2. Dst is the default destination. Dst is reached if none of the cases
7889-
// with destination != Dst are taken. Join the conditions for each case
7890-
// where the destination is != Dst using a logical OR and negate it.
7891-
VPValue *DefaultMask = nullptr;
7892-
for (const auto &[Dst, Conds] : Map) {
7893-
if (Dst == DefaultDst)
7894-
continue;
7895-
if (!DefaultMask) {
7896-
DefaultMask = EdgeMaskCache[{Src, Dst}];
7897-
continue;
7898-
}
7899-
DefaultMask = Builder.createOr(DefaultMask, EdgeMaskCache[{Src, Dst}]);
7900-
}
7901-
if (DefaultMask) {
7902-
DefaultMask = Builder.createNot(DefaultMask);
7903-
if (SrcMask)
7904-
DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
7905-
}
7906-
EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
7909+
createSwitchEdgeMasks(SI);
79077910
assert(EdgeMaskCache.contains(Edge) && "Mask for Edge not created?");
79087911
return EdgeMaskCache[Edge];
79097912
}
79107913

7914+
VPValue *SrcMask = getBlockInMask(Src);
7915+
79117916
// The terminator has to be a branch inst!
79127917
BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator());
79137918
assert(BI && "Unexpected terminator found");
7914-
79157919
if (!BI->isConditional() || BI->getSuccessor(0) == BI->getSuccessor(1))
79167920
return EdgeMaskCache[Edge] = SrcMask;
79177921

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ class VPRecipeBuilder {
134134
/// Returns the *entry* mask for the block \p BB.
135135
VPValue *getBlockInMask(BasicBlock *BB) const;
136136

137+
/// Create and cache the edge masks for all outgoing edges of switch \p SI:
138+
/// one mask per destination that differs from the default destination, and a
/// mask for the default destination.
139+
void createSwitchEdgeMasks(SwitchInst *SI);
140+
137141
/// A helper function that computes the predicate of the edge between SRC
138142
/// and DST.
139143
VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst);

0 commit comments

Comments
 (0)