Skip to content

Commit c6c4e77

Browse files
committed
[LoopVectorize] Enable vectorisation of early exit loops with live-outs
This work feeds part of PR #88385, and adds support for vectorising loops with uncountable early exits and outside users of loop-defined variables. When calculating the final value from an uncountable early exit we need to calculate the vector lane that triggered the exit, and hence determine the value at the point we exited. All code for calculating the last value when exiting the loop early now lives in a new vector.early.exit block, which sits between the middle.split block and the original exit block. Doing this required two fixes: 1. The vplan verifier incorrectly assumed that the block containing a definition always dominates the block of the user. That's not true if you can arrive at the use block from multiple incoming blocks. This is possible for early exit loops where both the early exit and the latch jump to the same block. I've added a new ExtractFirstActive VPInstruction that extracts the first active lane of a vector, i.e. the lane of the vector predicate that triggered the exit. NOTE: The IR generated for dealing with live-outs from early exit loops is unoptimised, as opposed to normal loops. This inevitably leads to poor quality code, but this can be fixed up later.
1 parent b4e17d4 commit c6c4e77

15 files changed

+1044
-167
lines changed

llvm/docs/Vectorizers.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,9 @@ Early Exit Vectorization
405405
When vectorizing a loop with a single early exit, the loop blocks following the
406406
early exit are predicated and the vector loop will always exit via the latch.
407407
If the early exit has been taken, the vector loop's successor block
408-
(``middle.split`` below) branches to the early exit block. Otherwise
409-
``middle.block`` selects between the exit block from the latch or the scalar
410-
remainder loop.
408+
(``middle.split`` below) branches to the early exit block via an intermediate
409+
block (``vector.early.exit`` below). Otherwise ``middle.block`` selects between
410+
the exit block from the latch or the scalar remainder loop.
411411

412412
.. image:: vplan-early-exit.png
413413

llvm/docs/vplan-early-exit.dot

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,27 @@ compound=true
1919
"middle.split"
2020
]
2121
N4 -> N5 [ label=""]
22-
N4 -> N6 [ label=""]
22+
N4 -> N7 [ label=""]
2323
N5 [label =
24-
"early.exit"
24+
"vector.early.exit"
2525
]
26+
N5 -> N6 [ label=""]
2627
N6 [label =
27-
"middle.block"
28+
"early.exit"
2829
]
29-
N6 -> N9 [ label=""]
30-
N6 -> N7 [ label=""]
3130
N7 [label =
32-
"scalar.ph"
31+
"middle.block"
3332
]
33+
N7 -> N10 [ label=""]
3434
N7 -> N8 [ label=""]
3535
N8 [label =
36-
"loop.header"
36+
"scalar.ph"
3737
]
38+
N8 -> N9 [ label=""]
3839
N9 [label =
40+
"loop.header"
41+
]
42+
N10 [label =
3943
"latch.exit"
4044
]
4145
}

llvm/docs/vplan-early-exit.png

-83.3 KB
Loading

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,11 @@ class LoopVectorizationLegality {
407407

408408
/// Returns the destination of an uncountable early exiting block.
409409
BasicBlock *getUncountableEarlyExitBlock() const {
410+
if (!HasUncountableEarlyExit) {
411+
assert(getUncountableExitBlocks().empty() &&
412+
"Expected no uncountable exiting blocks");
413+
return nullptr;
414+
}
410415
assert(getUncountableExitBlocks().size() == 1 &&
411416
"Expected only a single uncountable exit block");
412417
return getUncountableExitBlocks()[0];

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 67 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2775,6 +2775,23 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton(
27752775
return LoopVectorPreHeader;
27762776
}
27772777

2778+
static bool isValueIncomingFromBlock(BasicBlock *ExitingBB, Value *V,
2779+
Instruction *UI) {
2780+
PHINode *PHI = dyn_cast<PHINode>(UI);
2781+
assert(PHI && "Expected LCSSA form");
2782+
2783+
// If this loop has an uncountable early exit then there could be
2784+
// different users of OrigPhi with either:
2785+
// 1. Multiple users, because each exiting block (countable or
2786+
// uncountable) jumps to the same exit block, or ..
2787+
// 2. A single user with an incoming value from a countable or
2788+
// uncountable exiting block.
2789+
// In both cases there is no guarantee this came from a countable exiting
2790+
// block, i.e. the latch.
2791+
int Index = PHI->getBasicBlockIndex(ExitingBB);
2792+
return Index != -1 && PHI->getIncomingValue(Index) == V;
2793+
}
2794+
27782795
// Fix up external users of the induction variable. At this point, we are
27792796
// in LCSSA form, with all external PHIs that use the IV having one input value,
27802797
// coming from the remainder loop. We need those PHIs to also have a correct
@@ -2797,12 +2814,13 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
27972814

27982815
// An external user of the last iteration's value should see the value that
27992816
// the remainder loop uses to initialize its own IV.
2800-
Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoop->getLoopLatch());
2817+
BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch();
2818+
Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoopLatch);
28012819
for (User *U : PostInc->users()) {
28022820
Instruction *UI = cast<Instruction>(U);
28032821
if (!OrigLoop->contains(UI)) {
2804-
assert(isa<PHINode>(UI) && "Expected LCSSA form");
2805-
MissingVals[UI] = EndValue;
2822+
if (isValueIncomingFromBlock(OrigLoopLatch, PostInc, UI))
2823+
MissingVals[cast<PHINode>(UI)] = EndValue;
28062824
}
28072825
}
28082826

@@ -2812,7 +2830,8 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28122830
for (User *U : OrigPhi->users()) {
28132831
auto *UI = cast<Instruction>(U);
28142832
if (!OrigLoop->contains(UI)) {
2815-
assert(isa<PHINode>(UI) && "Expected LCSSA form");
2833+
if (!isValueIncomingFromBlock(OrigLoopLatch, OrigPhi, UI))
2834+
continue;
28162835
IRBuilder<> B(MiddleBlock->getTerminator());
28172836

28182837
// Fast-math-flags propagate from the original induction instruction.
@@ -2842,18 +2861,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28422861
}
28432862
}
28442863

2845-
assert((MissingVals.empty() ||
2846-
all_of(MissingVals,
2847-
[MiddleBlock, this](const std::pair<Value *, Value *> &P) {
2848-
return all_of(
2849-
predecessors(cast<Instruction>(P.first)->getParent()),
2850-
[MiddleBlock, this](BasicBlock *Pred) {
2851-
return Pred == MiddleBlock ||
2852-
Pred == OrigLoop->getLoopLatch();
2853-
});
2854-
})) &&
2855-
"Expected escaping values from latch/middle.block only");
2856-
28572864
for (auto &I : MissingVals) {
28582865
PHINode *PHI = cast<PHINode>(I.first);
28592866
// One corner case we have to handle is two IVs "chasing" each-other,
@@ -7774,6 +7781,9 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
77747781
State.LVer->prepareNoAliasMetadata();
77757782
}
77767783

7784+
// Set the uncountable early exit block in the VPTransformState.
7785+
State.CFG.UncountableEarlyExitBB = ILV.Legal->getUncountableEarlyExitBlock();
7786+
77777787
ILV.printDebugTracesAtStart();
77787788

77797789
//===------------------------------------------------===//
@@ -8958,6 +8968,9 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) {
89588968
// start value provides the value if the loop is bypassed.
89598969
bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR);
89608970
auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue();
8971+
assert(!Plan.getEarlyExit() &&
8972+
"Cannot handle reductions or first-order recurrences with "
8973+
"uncountable early exits");
89618974
if (IsFOR)
89628975
ResumeFromVectorLoop = MiddleBuilder.createNaryOp(
89638976
VPInstruction::ExtractFromEnd, {ResumeFromVectorLoop, OneVPV}, {},
@@ -9075,14 +9088,20 @@ collectUsersInExitBlocks(Loop *OrigLoop, VPRecipeBuilder &Builder,
90759088
// Add exit values to \p Plan. Extracts are added for each entry in \p
90769089
// ExitUsersToFix if needed and their operands are updated. Values exiting via
90779090
// the middle block use ExtractFromEnd; all others use ExtractFirstActive.
9078-
static bool
9091+
static void
90799092
addUsersInExitBlocks(VPlan &Plan,
90809093
const SetVector<VPIRInstruction *> &ExitUsersToFix) {
90819094
if (ExitUsersToFix.empty())
9082-
return true;
9095+
return;
90839096

90849097
auto *MiddleVPBB = Plan.getMiddleBlock();
9085-
VPBuilder B(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
9098+
VPBuilder MiddleB(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
9099+
VPBuilder EarlyExitB;
9100+
VPBasicBlock *VectorEarlyExitVPBB = Plan.getEarlyExit();
9101+
VPValue *EarlyExitMask = nullptr;
9102+
if (VectorEarlyExitVPBB)
9103+
EarlyExitB.setInsertPoint(VectorEarlyExitVPBB,
9104+
VectorEarlyExitVPBB->getFirstNonPhi());
90869105

90879106
// Introduce extract for exiting values and update the VPIRInstructions
90889107
// modeling the corresponding LCSSA phis.
@@ -9093,19 +9112,38 @@ addUsersInExitBlocks(VPlan &Plan,
90939112
if (Op->isLiveIn())
90949113
continue;
90959114

9096-
// Currently only live-ins can be used by exit values from blocks not
9097-
// exiting via the vector latch through to the middle block.
9098-
if (ExitIRI->getParent()->getSinglePredecessor() != MiddleVPBB)
9099-
return false;
9100-
91019115
LLVMContext &Ctx = ExitIRI->getInstruction().getContext();
9102-
VPValue *Ext = B.createNaryOp(VPInstruction::ExtractFromEnd,
9103-
{Op, Plan.getOrAddLiveIn(ConstantInt::get(
9104-
IntegerType::get(Ctx, 32), 1))});
9116+
VPValue *Ext;
9117+
VPBasicBlock *PredVPBB =
9118+
cast<VPBasicBlock>(ExitIRI->getParent()->getPredecessors()[Idx]);
9119+
if (PredVPBB != MiddleVPBB) {
9120+
assert(ExitIRI->getParent()->getNumPredecessors() <= 2);
9121+
9122+
// Cache the early exit mask
9123+
if (!EarlyExitMask) {
9124+
VPBasicBlock *MiddleSplitVPBB =
9125+
cast<VPBasicBlock>(VectorEarlyExitVPBB->getSinglePredecessor());
9126+
VPInstruction *PredTerm =
9127+
cast<VPInstruction>(MiddleSplitVPBB->getTerminator());
9128+
assert(PredTerm->getOpcode() == VPInstruction::BranchOnCond &&
9129+
"Unexpected middle split block terminator");
9130+
VPInstruction *ScalarCond =
9131+
cast<VPInstruction>(PredTerm->getOperand(0));
9132+
assert(
9133+
ScalarCond->getOpcode() == VPInstruction::AnyOf &&
9134+
"Unexpected condition for middle split block terminator branch");
9135+
EarlyExitMask = ScalarCond->getOperand(0);
9136+
}
9137+
Ext = EarlyExitB.createNaryOp(VPInstruction::ExtractFirstActive,
9138+
{Op, EarlyExitMask});
9139+
} else {
9140+
Ext = MiddleB.createNaryOp(VPInstruction::ExtractFromEnd,
9141+
{Op, Plan.getOrAddLiveIn(ConstantInt::get(
9142+
IntegerType::get(Ctx, 32), 1))});
9143+
}
91059144
ExitIRI->setOperand(Idx, Ext);
91069145
}
91079146
}
9108-
return true;
91099147
}
91109148

91119149
/// Handle users in the exit block for first order reductions in the original
@@ -9401,12 +9439,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
94019439
SetVector<VPIRInstruction *> ExitUsersToFix =
94029440
collectUsersInExitBlocks(OrigLoop, RecipeBuilder, *Plan);
94039441
addExitUsersForFirstOrderRecurrences(*Plan, ExitUsersToFix);
9404-
if (!addUsersInExitBlocks(*Plan, ExitUsersToFix)) {
9405-
reportVectorizationFailure(
9406-
"Some exit values in loop with uncountable exit not supported yet",
9407-
"UncountableEarlyExitLoopsUnsupportedExitValue", ORE, OrigLoop);
9408-
return nullptr;
9409-
}
9442+
addUsersInExitBlocks(*Plan, ExitUsersToFix);
94109443

94119444
// ---------------------------------------------------------------------------
94129445
// Transform initial VPlan: Apply previously taken decisions, in order, to

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,15 @@ void VPBasicBlock::execute(VPTransformState *State) {
501501
UnreachableInst *Terminator = State->Builder.CreateUnreachable();
502502
// Register NewBB in its loop. In innermost loops it's the same for all
503503
// BB's.
504-
if (State->CurrentParentLoop)
504+
if (this == State->Plan->getEarlyExit()) {
505+
// If this is the vector early exit block then it has a single successor,
506+
// which is the uncountable early exit block of the original loop. The
507+
// parent loop for the exit block may not be the same as the parent loop
508+
// of the vectorised loop, so we have to treat this differently.
509+
Loop *EEL = State->LI->getLoopFor(State->CFG.UncountableEarlyExitBB);
510+
if (EEL)
511+
EEL->addBasicBlockToLoop(NewBB, *State->LI);
512+
} else if (State->CurrentParentLoop)
505513
State->CurrentParentLoop->addBasicBlockToLoop(NewBB, *State->LI);
506514
State->Builder.SetInsertPoint(Terminator);
507515

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ struct VPTransformState {
347347
/// vector loop.
348348
BasicBlock *ExitBB = nullptr;
349349

350+
/// The uncountable early exit block in the original scalar loop.
351+
BasicBlock *UncountableEarlyExitBB = nullptr;
352+
350353
/// A mapping of each VPBasicBlock to the corresponding BasicBlock. In case
351354
/// of replication, maps the BasicBlock of the last replica created.
352355
SmallDenseMap<VPBasicBlock *, BasicBlock *> VPBB2IRBB;
@@ -1225,6 +1228,9 @@ class VPInstruction : public VPRecipeWithIRFlags,
12251228
// Returns a scalar boolean value, which is true if any lane of its single
12261229
// operand is true.
12271230
AnyOf,
1231+
// Extracts the first active lane of a vector, where the first operand is
1232+
// the vector to extract from, and the second operand is the predicate.
1233+
ExtractFirstActive,
12281234
};
12291235

12301236
private:
@@ -3871,6 +3877,22 @@ class VPlan {
38713877
VPRegionBlock *getVectorLoopRegion();
38723878
const VPRegionBlock *getVectorLoopRegion() const;
38733879

3880+
/// Get the vector early exit block
3881+
VPBasicBlock *getEarlyExit() {
3882+
auto LoopRegion = getVectorLoopRegion();
3883+
if (!LoopRegion)
3884+
return nullptr;
3885+
3886+
auto *SuccessorVPBB = LoopRegion->getSingleSuccessor();
3887+
auto *MiddleVPBB = getMiddleBlock();
3888+
if (SuccessorVPBB == MiddleVPBB)
3889+
return nullptr;
3890+
3891+
assert(SuccessorVPBB->getSuccessors()[1] == MiddleVPBB &&
3892+
"Expected second successor to be the middle block");
3893+
return cast<VPBasicBlock>(SuccessorVPBB->getSuccessors()[0]);
3894+
}
3895+
38743896
/// Returns the 'middle' block of the plan, that is the block that selects
38753897
/// whether to execute the scalar tail loop or the exit block from the loop
38763898
/// latch.

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,14 +630,21 @@ Value *VPInstruction::generate(VPTransformState &State) {
630630
Value *A = State.get(getOperand(0));
631631
return Builder.CreateOrReduce(A);
632632
}
633-
633+
case VPInstruction::ExtractFirstActive: {
634+
Value *Vec = State.get(getOperand(0));
635+
Value *Mask = State.get(getOperand(1));
636+
Value *Ctz =
637+
Builder.CreateCountTrailingZeroElems(Builder.getInt64Ty(), Mask);
638+
return Builder.CreateExtractElement(Vec, Ctz);
639+
}
634640
default:
635641
llvm_unreachable("Unsupported opcode for instruction");
636642
}
637643
}
638644

639645
bool VPInstruction::isVectorToScalar() const {
640646
return getOpcode() == VPInstruction::ExtractFromEnd ||
647+
getOpcode() == VPInstruction::ExtractFirstActive ||
641648
getOpcode() == VPInstruction::ComputeReductionResult ||
642649
getOpcode() == VPInstruction::AnyOf;
643650
}
@@ -702,6 +709,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
702709
case VPInstruction::CalculateTripCountMinusVF:
703710
case VPInstruction::CanonicalIVIncrementForPart:
704711
case VPInstruction::ExtractFromEnd:
712+
case VPInstruction::ExtractFirstActive:
705713
case VPInstruction::FirstOrderRecurrenceSplice:
706714
case VPInstruction::LogicalAnd:
707715
case VPInstruction::Not:
@@ -821,6 +829,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
821829
case VPInstruction::AnyOf:
822830
O << "any-of";
823831
break;
832+
case VPInstruction::ExtractFirstActive:
833+
O << "extract-first-active";
834+
break;
824835
default:
825836
O << Instruction::getOpcodeName(getOpcode());
826837
}

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1973,10 +1973,13 @@ void VPlanTransforms::handleUncountableEarlyExit(
19731973
Builder.createNaryOp(VPInstruction::AnyOf, {EarlyExitTakenCond});
19741974

19751975
VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
1976+
VPBasicBlock *EarlyExitVPBB = Plan.createVPBasicBlock("vector.early.exit");
19761977
VPBlockUtils::insertOnEdge(LoopRegion, MiddleVPBB, NewMiddle);
1977-
VPBlockUtils::connectBlocks(NewMiddle, VPEarlyExitBlock);
1978+
VPBlockUtils::connectBlocks(NewMiddle, EarlyExitVPBB);
19781979
NewMiddle->swapSuccessors();
19791980

1981+
VPBlockUtils::connectBlocks(EarlyExitVPBB, VPEarlyExitBlock);
1982+
19801983
VPBuilder MiddleBuilder(NewMiddle);
19811984
MiddleBuilder.createNaryOp(VPInstruction::BranchOnCond, {IsEarlyExitTaken});
19821985

llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,11 @@ bool VPlanVerifier::verifyVPBasicBlock(const VPBasicBlock *VPBB) {
212212
continue;
213213
}
214214

215-
if (!VPDT.dominates(VPBB, UI->getParent())) {
215+
// Now that we support vectorising loops with uncountable early exits
216+
// we can end up in situations where VPBB does not dominate the exit
217+
// block. Only do the check if the user is not in a VPIRBasicBlock.
218+
if (!isa<VPIRBasicBlock>(UI->getParent()) &&
219+
!VPDT.dominates(VPBB, UI->getParent())) {
216220
errs() << "Use before def!\n";
217221
return false;
218222
}

0 commit comments

Comments
 (0)