Skip to content

Commit 47258de

Browse files
committed
[VPlan] Dispatch to multiple exit blocks via middle blocks.
A more lightweight variant of #109193, which dispatches to multiple exit blocks via the middle blocks.
1 parent 245b56a commit 47258de

File tree

12 files changed

+614
-56
lines changed

12 files changed

+614
-56
lines changed

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,9 @@ class LoopVectorizationLegality {
287287
/// we can use in-order reductions.
288288
bool canVectorizeFPMath(bool EnableStrictReductions);
289289

290+
/// Returns true if the loop has an early exit that we can vectorize.
291+
bool canVectorizeEarlyExit() const;
292+
290293
/// Return true if we can vectorize this loop while folding its tail by
291294
/// masking.
292295
bool canFoldTailByMasking() const;

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ AllowStridedPointerIVs("lv-strided-pointer-ivs", cl::init(false), cl::Hidden,
4343
cl::desc("Enable recognition of non-constant strided "
4444
"pointer induction variables."));
4545

46+
static cl::opt<bool>
47+
EnableEarlyExitVectorization("enable-early-exit-vectorization",
48+
cl::init(false), cl::Hidden, cl::desc(""));
49+
4650
namespace llvm {
4751
cl::opt<bool>
4852
HintsAllowReordering("hints-allow-reordering", cl::init(true), cl::Hidden,
@@ -1378,6 +1382,10 @@ bool LoopVectorizationLegality::isFixedOrderRecurrence(
13781382
}
13791383

13801384
bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) const {
1385+
// When vectorizing early exits, create predicates for all blocks, except the
1386+
// header.
1387+
if (canVectorizeEarlyExit() && BB != TheLoop->getHeader())
1388+
return true;
13811389
return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
13821390
}
13831391

@@ -1514,6 +1522,27 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
15141522
return true;
15151523
}
15161524

1525+
bool LoopVectorizationLegality::canVectorizeEarlyExit() const {
1526+
// Currently only allow vectorizing loops with early exits, if early-exit
1527+
// vectorization is explicitly enabled and the loop has metadata to force
1528+
// vectorization.
1529+
if (!EnableEarlyExitVectorization)
1530+
return false;
1531+
1532+
SmallVector<BasicBlock *> Exiting;
1533+
TheLoop->getExitingBlocks(Exiting);
1534+
if (Exiting.size() == 1)
1535+
return false;
1536+
1537+
LoopVectorizeHints Hints(TheLoop, true, *ORE);
1538+
if (Hints.getForce() == LoopVectorizeHints::FK_Undefined)
1539+
return false;
1540+
1541+
Function *Fn = TheLoop->getHeader()->getParent();
1542+
return Hints.allowVectorization(Fn, TheLoop,
1543+
true /*VectorizeOnlyWhenForced*/);
1544+
}
1545+
15171546
// Helper function to canVectorizeLoopNestCFG.
15181547
bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
15191548
bool UseVPlanNativePath) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,9 +1363,11 @@ class LoopVectorizationCostModel {
13631363
// If we might exit from anywhere but the latch, must run the exiting
13641364
// iteration in scalar form.
13651365
if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) {
1366-
LLVM_DEBUG(
1367-
dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
1368-
return true;
1366+
if (!Legal->canVectorizeEarlyExit()) {
1367+
LLVM_DEBUG(
1368+
dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
1369+
return true;
1370+
}
13691371
}
13701372
if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue()) {
13711373
LLVM_DEBUG(dbgs() << "LV: Loop requires scalar epilogue: "
@@ -2575,7 +2577,8 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) {
25752577
LoopVectorPreHeader = OrigLoop->getLoopPreheader();
25762578
assert(LoopVectorPreHeader && "Invalid loop structure");
25772579
LoopExitBlock = OrigLoop->getUniqueExitBlock(); // may be nullptr
2578-
assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
2580+
assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector()) ||
2581+
Legal->canVectorizeEarlyExit()) &&
25792582
"multiple exit loop without required epilogue?");
25802583

25812584
LoopMiddleBlock =
@@ -2758,8 +2761,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
27582761
// value (the value that feeds into the phi from the loop latch).
27592762
// We allow both, but they, obviously, have different values.
27602763

2761-
assert(OrigLoop->getUniqueExitBlock() && "Expected a single exit block");
2762-
27632764
DenseMap<Value *, Value *> MissingVals;
27642765

27652766
// An external user of the last iteration's value should see the value that
@@ -2819,6 +2820,9 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28192820
if (PHI->getBasicBlockIndex(MiddleBlock) == -1)
28202821
PHI->addIncoming(I.second, MiddleBlock);
28212822
}
2823+
2824+
assert((MissingVals.empty() || OrigLoop->getUniqueExitBlock()) &&
2825+
"Expected a single exit block");
28222826
}
28232827

28242828
namespace {
@@ -3599,7 +3603,8 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
35993603
TheLoop->getExitingBlocks(Exiting);
36003604
for (BasicBlock *E : Exiting) {
36013605
auto *Cmp = dyn_cast<Instruction>(E->getTerminator()->getOperand(0));
3602-
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse())
3606+
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse() &&
3607+
(TheLoop->getLoopLatch() == E || !Legal->canVectorizeEarlyExit()))
36033608
AddToWorklistIfAllowed(Cmp);
36043609
}
36053610

@@ -7692,12 +7697,15 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
76927697
BestVPlan.execute(&State);
76937698

76947699
// 2.5 Collect reduction resume values.
7695-
auto *ExitVPBB =
7696-
cast<VPBasicBlock>(BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
7697-
for (VPRecipeBase &R : *ExitVPBB) {
7698-
createAndCollectMergePhiForReduction(
7699-
dyn_cast<VPInstruction>(&R), State, OrigLoop,
7700-
State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
7700+
VPBasicBlock *ExitVPBB = nullptr;
7701+
if (BestVPlan.getVectorLoopRegion()->getSingleSuccessor()) {
7702+
ExitVPBB = cast<VPBasicBlock>(
7703+
BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
7704+
for (VPRecipeBase &R : *ExitVPBB) {
7705+
createAndCollectMergePhiForReduction(
7706+
dyn_cast<VPInstruction>(&R), State, OrigLoop,
7707+
State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
7708+
}
77017709
}
77027710

77037711
// 2.6. Maintain Loop Hints
@@ -7723,6 +7731,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
77237731
LoopVectorizeHints Hints(L, true, *ORE);
77247732
Hints.setAlreadyVectorized();
77257733
}
7734+
77267735
TargetTransformInfo::UnrollingPreferences UP;
77277736
TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE);
77287737
if (!UP.UnrollVectorizedLoop || CanonicalIVStartValue)
@@ -7735,15 +7744,17 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
77357744
ILV.printDebugTracesAtEnd();
77367745

77377746
// 4. Adjust branch weight of the branch in the middle block.
7738-
auto *MiddleTerm =
7739-
cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
7740-
if (MiddleTerm->isConditional() &&
7741-
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
7742-
// Assume that `Count % VectorTripCount` is equally distributed.
7743-
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
7744-
assert(TripCount > 0 && "trip count should not be zero");
7745-
const uint32_t Weights[] = {1, TripCount - 1};
7746-
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
7747+
if (ExitVPBB) {
7748+
auto *MiddleTerm =
7749+
cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
7750+
if (MiddleTerm->isConditional() &&
7751+
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
7752+
// Assume that `Count % VectorTripCount` is equally distributed.
7753+
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
7754+
assert(TripCount > 0 && "trip count should not be zero");
7755+
const uint32_t Weights[] = {1, TripCount - 1};
7756+
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
7757+
}
77477758
}
77487759

77497760
return State.ExpandedSCEVs;
@@ -8128,7 +8139,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
81288139
// If source is an exiting block, we know the exit edge is dynamically dead
81298140
// in the vector loop, and thus we don't need to restrict the mask. Avoid
81308141
// adding uses of an otherwise potentially dead instruction.
8131-
if (OrigLoop->isLoopExiting(Src))
8142+
if (!Legal->canVectorizeEarlyExit() && OrigLoop->isLoopExiting(Src))
81328143
return EdgeMaskCache[Edge] = SrcMask;
81338144

81348145
VPValue *EdgeMask = getVPValueOrAddLiveIn(BI->getCondition());
@@ -8778,6 +8789,8 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW,
87788789
static SetVector<VPIRInstruction *> collectUsersInExitBlock(
87798790
Loop *OrigLoop, VPRecipeBuilder &Builder, VPlan &Plan,
87808791
const MapVector<PHINode *, InductionDescriptor> &Inductions) {
8792+
if (!Plan.getVectorLoopRegion()->getSingleSuccessor())
8793+
return {};
87818794
auto *MiddleVPBB =
87828795
cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor());
87838796
// No edge from the middle block to the unique exit block has been inserted
@@ -8863,6 +8876,8 @@ static void addLiveOutsForFirstOrderRecurrences(
88638876
// TODO: Should be replaced by
88648877
// Plan->getScalarLoopRegion()->getSinglePredecessor() in the future once the
88658878
// scalar region is modeled as well.
8879+
if (!VectorRegion->getSingleSuccessor())
8880+
return;
88668881
auto *MiddleVPBB = cast<VPBasicBlock>(VectorRegion->getSingleSuccessor());
88678882
VPBasicBlock *ScalarPHVPBB = nullptr;
88688883
if (MiddleVPBB->getNumSuccessors() == 2) {
@@ -9146,10 +9161,15 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
91469161
"VPBasicBlock");
91479162
RecipeBuilder.fixHeaderPhis();
91489163

9149-
SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlock(
9150-
OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars());
9151-
addLiveOutsForFirstOrderRecurrences(*Plan, ExitUsersToFix);
9152-
addUsersInExitBlock(*Plan, ExitUsersToFix);
9164+
if (Legal->canVectorizeEarlyExit()) {
9165+
VPlanTransforms::convertToMultiCond(*Plan, *PSE.getSE(), OrigLoop,
9166+
RecipeBuilder);
9167+
} else {
9168+
SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlock(
9169+
OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars());
9170+
addLiveOutsForFirstOrderRecurrences(*Plan, ExitUsersToFix);
9171+
addUsersInExitBlock(*Plan, ExitUsersToFix);
9172+
}
91539173

91549174
// ---------------------------------------------------------------------------
91559175
// Transform initial VPlan: Apply previously taken decisions, in order, to
@@ -9277,8 +9297,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92779297
using namespace VPlanPatternMatch;
92789298
VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion();
92799299
VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock();
9280-
VPBasicBlock *MiddleVPBB =
9281-
cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
92829300
for (VPRecipeBase &R : Header->phis()) {
92839301
auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
92849302
if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered()))
@@ -9297,8 +9315,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92979315
for (VPUser *U : Cur->users()) {
92989316
auto *UserRecipe = cast<VPSingleDefRecipe>(U);
92999317
if (!UserRecipe->getParent()->getEnclosingLoopRegion()) {
9300-
assert(UserRecipe->getParent() == MiddleVPBB &&
9301-
"U must be either in the loop region or the middle block.");
93029318
continue;
93039319
}
93049320
Worklist.insert(UserRecipe);
@@ -9403,6 +9419,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
94039419
}
94049420
VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock();
94059421
Builder.setInsertPoint(&*LatchVPBB->begin());
9422+
if (!VectorLoopRegion->getSingleSuccessor())
9423+
return;
9424+
VPBasicBlock *MiddleVPBB =
9425+
cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
94069426
VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi();
94079427
for (VPRecipeBase &R :
94089428
Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,14 @@ void VPIRBasicBlock::execute(VPTransformState *State) {
474474
// backedges. A backward successor is set when the branch is created.
475475
const auto &PredVPSuccessors = PredVPBB->getHierarchicalSuccessors();
476476
unsigned idx = PredVPSuccessors.front() == this ? 0 : 1;
477+
if (TermBr->getSuccessor(idx) &&
478+
PredVPBlock == getPlan()->getVectorLoopRegion() &&
479+
PredVPBlock->getNumSuccessors()) {
480+
// Update PRedBB and TermBr for BranchOnMultiCond in predecessor.
481+
PredBB = TermBr->getSuccessor(1);
482+
TermBr = cast<BranchInst>(PredBB->getTerminator());
483+
idx = 0;
484+
}
477485
assert(!TermBr->getSuccessor(idx) &&
478486
"Trying to reset an existing successor block.");
479487
TermBr->setSuccessor(idx, IRBB);
@@ -908,8 +916,8 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
908916
VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
909917
VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);
910918

911-
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
912919
if (!RequiresScalarEpilogueCheck) {
920+
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
913921
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
914922
return Plan;
915923
}
@@ -923,10 +931,14 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
923931
// we unconditionally branch to the scalar preheader. Do nothing.
924932
// 3) Otherwise, construct a runtime check.
925933
BasicBlock *IRExitBlock = TheLoop->getUniqueExitBlock();
926-
auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
927-
// The connection order corresponds to the operands of the conditional branch.
928-
VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
929-
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
934+
if (IRExitBlock) {
935+
auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
936+
// The connection order corresponds to the operands of the conditional
937+
// branch.
938+
VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
939+
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
940+
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
941+
}
930942

931943
auto *ScalarLatchTerm = TheLoop->getLoopLatch()->getTerminator();
932944
// Here we use the same DebugLoc as the scalar loop latch terminator instead
@@ -1031,7 +1043,9 @@ void VPlan::execute(VPTransformState *State) {
10311043
// VPlan execution rather than earlier during VPlan construction.
10321044
BasicBlock *MiddleBB = State->CFG.ExitBB;
10331045
VPBasicBlock *MiddleVPBB =
1034-
cast<VPBasicBlock>(getVectorLoopRegion()->getSingleSuccessor());
1046+
getVectorLoopRegion()->getNumSuccessors() == 1
1047+
? cast<VPBasicBlock>(getVectorLoopRegion()->getSuccessors()[0])
1048+
: cast<VPBasicBlock>(getVectorLoopRegion()->getSuccessors()[1]);
10351049
// Find the VPBB for the scalar preheader, relying on the current structure
10361050
// when creating the middle block and its successrs: if there's a single
10371051
// predecessor, it must be the scalar preheader. Otherwise, the second
@@ -1044,6 +1058,10 @@ void VPlan::execute(VPTransformState *State) {
10441058
MiddleSuccs.size() == 1 ? MiddleSuccs[0] : MiddleSuccs[1]);
10451059
assert(!isa<VPIRBasicBlock>(ScalarPhVPBB) &&
10461060
"scalar preheader cannot be wrapped already");
1061+
if (ScalarPhVPBB->getNumSuccessors() != 0) {
1062+
ScalarPhVPBB = cast<VPBasicBlock>(ScalarPhVPBB->getSuccessors()[1]);
1063+
MiddleVPBB = cast<VPBasicBlock>(MiddleVPBB->getSuccessors()[1]);
1064+
}
10471065
replaceVPBBWithIRVPBB(ScalarPhVPBB, ScalarPh);
10481066
replaceVPBBWithIRVPBB(MiddleVPBB, MiddleBB);
10491067

@@ -1065,6 +1083,10 @@ void VPlan::execute(VPTransformState *State) {
10651083
VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock();
10661084
BasicBlock *VectorLatchBB = State->CFG.VPBB2IRBB[LatchVPBB];
10671085

1086+
if (!getVectorLoopRegion()->getSingleSuccessor())
1087+
VectorLatchBB =
1088+
cast<BranchInst>(VectorLatchBB->getTerminator())->getSuccessor(1);
1089+
10681090
// Fix the latch value of canonical, reduction and first-order recurrences
10691091
// phis in the vector loop.
10701092
VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock();
@@ -1091,7 +1113,10 @@ void VPlan::execute(VPTransformState *State) {
10911113
// Move the last step to the end of the latch block. This ensures
10921114
// consistent placement of all induction updates.
10931115
Instruction *Inc = cast<Instruction>(Phi->getIncomingValue(1));
1094-
Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode());
1116+
if (VectorLatchBB->getTerminator() == &*VectorLatchBB->getFirstNonPHI())
1117+
Inc->moveBefore(VectorLatchBB->getTerminator());
1118+
else
1119+
Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode());
10951120

10961121
// Use the steps for the last part as backedge value for the induction.
10971122
if (auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R))

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,6 +1274,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
12741274
// operand). Only generates scalar values (either for the first lane only or
12751275
// for all lanes, depending on its uses).
12761276
PtrAdd,
1277+
AnyOf,
12771278
};
12781279

12791280
private:

0 commit comments

Comments
 (0)