Skip to content

Commit 8c28e3f

Browse files
committed
[LoopVectorize] Enable vectorisation of early exit loops with live-outs
This work feeds part of PR #88385, and adds support for vectorising loops with uncountable early exits and outside users of loop-defined variables. I've added a new fixupEarlyExitIVUsers to mirror what happens in fixupIVUsers when patching up outside users of induction variables in the early exit block. We have to handle these differently for two reasons: 1. We can't work backwards from the end value in the middle block because we didn't leave at the last iteration. 2. We need to generate different IR that calculates the vector lane that triggered the exit, and hence can determine the induction value at the point we exited. I've added a new 'null' VPValue as a dummy placeholder to manage the incoming operands of PHI nodes in the exit block. We can have situations where one of the incoming values is an induction variable (or its update) and the other is not. For example, both the latch and the early exiting block can jump to the same exit block. However, VPInstruction::generate walks through all predecessors of the PHI assuming the value is *not* an IV. In order to ensure that we process the right value for the right incoming block we use this new 'null' value is a marker to indicate it should be skipped, since it will be handled separately in fixupIVUsers or fixupEarlyExitIVUsers. All code for calculating the last value when exiting the loop early now lives in a new vector.early.exit block, which sits between the middle.split block and the original exit block. I also had to fix up the vplan verifier because it assumed that the block containing a definition always dominated the parent of the user. That's no longer the case because we can arrive at the exit block via one of the latch or the early exiting block. I've added a new ExtractFirstActive VPInstruction that extracts the first active lane of a vector, i.e. the lane of the vector predicate that triggered the exit.
1 parent beea5ac commit 8c28e3f

14 files changed

+1176
-198
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 164 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,11 @@ class InnerLoopVectorizer {
563563
Value *VectorTripCount, BasicBlock *MiddleBlock,
564564
VPTransformState &State);
565565

566+
void fixupEarlyExitIVUsers(PHINode *OrigPhi, const InductionDescriptor &II,
567+
BasicBlock *VectorEarlyExitBB,
568+
BasicBlock *MiddleBlock, VPlan &Plan,
569+
VPTransformState &State);
570+
566571
/// Iteratively sink the scalarized operands of a predicated instruction into
567572
/// the block that was created for it.
568573
void sinkScalarOperands(Instruction *PredInst);
@@ -2838,6 +2843,23 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton(
28382843
return LoopVectorPreHeader;
28392844
}
28402845

2846+
static bool isValueIncomingFromBlock(BasicBlock *ExitingBB, Value *V,
2847+
Instruction *UI) {
2848+
PHINode *PHI = dyn_cast<PHINode>(UI);
2849+
assert(PHI && "Expected LCSSA form");
2850+
2851+
// If this loop has an uncountable early exit then there could be
2852+
// different users of OrigPhi with either:
2853+
// 1. Multiple users, because each exiting block (countable or
2854+
// uncountable) jumps to the same exit block, or ..
2855+
// 2. A single user with an incoming value from a countable or
2856+
// uncountable exiting block.
2857+
// In both cases there is no guarantee this came from a countable exiting
2858+
// block, i.e. the latch.
2859+
int Index = PHI->getBasicBlockIndex(ExitingBB);
2860+
return Index != -1 && PHI->getIncomingValue(Index) == V;
2861+
}
2862+
28412863
// Fix up external users of the induction variable. At this point, we are
28422864
// in LCSSA form, with all external PHIs that use the IV having one input value,
28432865
// coming from the remainder loop. We need those PHIs to also have a correct
@@ -2853,19 +2875,20 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28532875
// We allow both, but they, obviously, have different values.
28542876

28552877
DenseMap<Value *, Value *> MissingVals;
2878+
BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch();
28562879

28572880
Value *EndValue = cast<PHINode>(OrigPhi->getIncomingValueForBlock(
28582881
OrigLoop->getLoopPreheader()))
28592882
->getIncomingValueForBlock(MiddleBlock);
28602883

28612884
// An external user of the last iteration's value should see the value that
28622885
// the remainder loop uses to initialize its own IV.
2863-
Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoop->getLoopLatch());
2886+
Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoopLatch);
28642887
for (User *U : PostInc->users()) {
28652888
Instruction *UI = cast<Instruction>(U);
28662889
if (!OrigLoop->contains(UI)) {
2867-
assert(isa<PHINode>(UI) && "Expected LCSSA form");
2868-
MissingVals[UI] = EndValue;
2890+
if (isValueIncomingFromBlock(OrigLoopLatch, PostInc, UI))
2891+
MissingVals[cast<PHINode>(UI)] = EndValue;
28692892
}
28702893
}
28712894

@@ -2875,7 +2898,9 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28752898
for (User *U : OrigPhi->users()) {
28762899
auto *UI = cast<Instruction>(U);
28772900
if (!OrigLoop->contains(UI)) {
2878-
assert(isa<PHINode>(UI) && "Expected LCSSA form");
2901+
if (!isValueIncomingFromBlock(OrigLoopLatch, OrigPhi, UI))
2902+
continue;
2903+
28792904
IRBuilder<> B(MiddleBlock->getTerminator());
28802905

28812906
// Fast-math-flags propagate from the original induction instruction.
@@ -2905,18 +2930,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
29052930
}
29062931
}
29072932

2908-
assert((MissingVals.empty() ||
2909-
all_of(MissingVals,
2910-
[MiddleBlock, this](const std::pair<Value *, Value *> &P) {
2911-
return all_of(
2912-
predecessors(cast<Instruction>(P.first)->getParent()),
2913-
[MiddleBlock, this](BasicBlock *Pred) {
2914-
return Pred == MiddleBlock ||
2915-
Pred == OrigLoop->getLoopLatch();
2916-
});
2917-
})) &&
2918-
"Expected escaping values from latch/middle.block only");
2919-
29202933
for (auto &I : MissingVals) {
29212934
PHINode *PHI = cast<PHINode>(I.first);
29222935
// One corner case we have to handle is two IVs "chasing" each-other,
@@ -2929,6 +2942,102 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
29292942
}
29302943
}
29312944

2945+
void InnerLoopVectorizer::fixupEarlyExitIVUsers(PHINode *OrigPhi,
2946+
const InductionDescriptor &II,
2947+
BasicBlock *VectorEarlyExitBB,
2948+
BasicBlock *MiddleBlock,
2949+
VPlan &Plan,
2950+
VPTransformState &State) {
2951+
// There are two kinds of external IV usages - those that use the value
2952+
// computed in the last iteration (the PHI) and those that use the penultimate
2953+
// value (the value that feeds into the phi from the loop latch).
2954+
// We allow both, but they, obviously, have different values.
2955+
DenseMap<Value *, Value *> MissingVals;
2956+
BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch();
2957+
BasicBlock *EarlyExitingBB = Legal->getUncountableEarlyExitingBlock();
2958+
Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoopLatch);
2959+
2960+
// Obtain the canonical IV, since we have to use the most recent value
2961+
// before exiting the loop early. This is unlike fixupIVUsers, which has
2962+
// the luxury of using the end value in the middle block.
2963+
VPBasicBlock *EntryVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
2964+
// NOTE: We cannot call Plan.getCanonicalIV() here because the original
2965+
// recipe created whilst building plans is no longer valid.
2966+
VPHeaderPHIRecipe *CanonicalIVR =
2967+
cast<VPHeaderPHIRecipe>(&*EntryVPBB->begin());
2968+
Value *CanonicalIV = State.get(CanonicalIVR->getVPSingleValue(), true);
2969+
2970+
// Search for the mask that drove us to exit early.
2971+
VPBasicBlock *EarlyExitVPBB = Plan.getVectorLoopRegion()->getEarlyExit();
2972+
VPBasicBlock *MiddleSplitVPBB =
2973+
cast<VPBasicBlock>(EarlyExitVPBB->getSinglePredecessor());
2974+
VPInstruction *BranchOnCond =
2975+
cast<VPInstruction>(MiddleSplitVPBB->getTerminator());
2976+
assert(BranchOnCond->getOpcode() == VPInstruction::BranchOnCond &&
2977+
"Expected middle.split block terminator to be a branch-on-cond");
2978+
VPInstruction *ScalarEarlyExitCond =
2979+
cast<VPInstruction>(BranchOnCond->getOperand(0));
2980+
assert(
2981+
ScalarEarlyExitCond->getOpcode() == VPInstruction::AnyOf &&
2982+
"Expected middle.split block terminator branch condition to be any-of");
2983+
VPValue *VectorEarlyExitCond = ScalarEarlyExitCond->getOperand(0);
2984+
// Finally get the mask that led us into the early exit block.
2985+
Value *EarlyExitMask = State.get(VectorEarlyExitCond);
2986+
2987+
// Calculate the IV step.
2988+
VPValue *StepVPV = Plan.getSCEVExpansion(II.getStep());
2989+
assert(StepVPV && "step must have been expanded during VPlan execution");
2990+
Value *Step = StepVPV->isLiveIn() ? StepVPV->getLiveInIRValue()
2991+
: State.get(StepVPV, VPLane(0));
2992+
2993+
auto FixUpPhi = [&](Instruction *UI, bool PostInc) -> Value * {
2994+
IRBuilder<> B(VectorEarlyExitBB->getTerminator());
2995+
assert(isa<PHINode>(UI) && "Expected LCSSA form");
2996+
2997+
// Fast-math-flags propagate from the original induction instruction.
2998+
if (isa_and_nonnull<FPMathOperator>(II.getInductionBinOp()))
2999+
B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags());
3000+
3001+
Type *CtzType = CanonicalIV->getType();
3002+
Value *Ctz = B.CreateCountTrailingZeroElems(CtzType, EarlyExitMask);
3003+
Ctz = B.CreateAdd(Ctz, cast<PHINode>(CanonicalIV));
3004+
if (PostInc)
3005+
Ctz = B.CreateAdd(Ctz, ConstantInt::get(CtzType, 1));
3006+
3007+
Value *Escape = emitTransformedIndex(B, Ctz, II.getStartValue(), Step,
3008+
II.getKind(), II.getInductionBinOp());
3009+
Escape->setName("ind.early.escape");
3010+
return Escape;
3011+
};
3012+
3013+
for (User *U : PostInc->users()) {
3014+
auto *UI = cast<Instruction>(U);
3015+
if (!OrigLoop->contains(UI)) {
3016+
if (isValueIncomingFromBlock(EarlyExitingBB, PostInc, UI))
3017+
MissingVals[UI] = FixUpPhi(UI, true);
3018+
}
3019+
}
3020+
3021+
for (User *U : OrigPhi->users()) {
3022+
auto *UI = cast<Instruction>(U);
3023+
if (!OrigLoop->contains(UI)) {
3024+
if (isValueIncomingFromBlock(EarlyExitingBB, OrigPhi, UI))
3025+
MissingVals[UI] = FixUpPhi(UI, false);
3026+
}
3027+
}
3028+
3029+
for (auto &I : MissingVals) {
3030+
PHINode *PHI = cast<PHINode>(I.first);
3031+
// One corner case we have to handle is two IVs "chasing" each-other,
3032+
// that is %IV2 = phi [...], [ %IV1, %latch ]
3033+
// In this case, if IV1 has an external use, we need to avoid adding both
3034+
// "last value of IV1" and "penultimate value of IV2". So, verify that we
3035+
// don't already have an incoming value for the middle block.
3036+
if (PHI->getBasicBlockIndex(VectorEarlyExitBB) == -1)
3037+
PHI->addIncoming(I.second, VectorEarlyExitBB);
3038+
}
3039+
}
3040+
29323041
namespace {
29333042

29343043
struct CSEDenseMapInfo {
@@ -3062,6 +3171,13 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
30623171
OuterLoop->addBasicBlockToLoop(MiddleSplitBB, *LI);
30633172
PredVPBB = PredVPBB->getSinglePredecessor();
30643173
}
3174+
3175+
BasicBlock *OrigEarlyExitBB = Legal->getUncountableEarlyExitBlock();
3176+
if (Loop *EEL = LI->getLoopFor(OrigEarlyExitBB)) {
3177+
BasicBlock *VectorEarlyExitBB =
3178+
State.CFG.VPBB2IRBB[VectorRegion->getEarlyExit()];
3179+
EEL->addBasicBlockToLoop(VectorEarlyExitBB, *LI);
3180+
}
30653181
}
30663182

30673183
// After vectorization, the exit blocks of the original loop will have
@@ -3091,6 +3207,15 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
30913207
getOrCreateVectorTripCount(nullptr), LoopMiddleBlock, State);
30923208
}
30933209

3210+
if (Legal->hasUncountableEarlyExit()) {
3211+
VPBasicBlock *VectorEarlyExitVPBB =
3212+
cast<VPBasicBlock>(VectorRegion->getEarlyExit());
3213+
BasicBlock *VectorEarlyExitBB = State.CFG.VPBB2IRBB[VectorEarlyExitVPBB];
3214+
for (const auto &Entry : Legal->getInductionVars())
3215+
fixupEarlyExitIVUsers(Entry.first, Entry.second, VectorEarlyExitBB,
3216+
LoopMiddleBlock, Plan, State);
3217+
}
3218+
30943219
for (Instruction *PI : PredicatedInstructions)
30953220
sinkScalarOperands(&*PI);
30963221

@@ -8974,6 +9099,9 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) {
89749099
auto *VectorPhiR = cast<VPHeaderPHIRecipe>(Builder.getRecipe(ScalarPhiI));
89759100
if (!isa<VPFirstOrderRecurrencePHIRecipe, VPReductionPHIRecipe>(VectorPhiR))
89769101
continue;
9102+
assert(!Plan.getVectorLoopRegion()->getEarlyExit() &&
9103+
"Cannot handle "
9104+
"first-order recurrences with uncountable early exits");
89779105
// The backedge value provides the value to resume coming out of a loop,
89789106
// which for FORs is a vector whose last element needs to be extracted. The
89799107
// start value provides the value if the loop is bypassed.
@@ -9032,8 +9160,7 @@ static SetVector<VPIRInstruction *> collectUsersInExitBlocks(
90329160
auto *P = dyn_cast<PHINode>(U);
90339161
return P && Inductions.contains(P);
90349162
}))) {
9035-
if (ExitVPBB->getSinglePredecessor() == MiddleVPBB)
9036-
continue;
9163+
V = VPValue::getNull();
90379164
}
90389165
ExitUsersToFix.insert(ExitIRI);
90399166
ExitIRI->addOperand(V);
@@ -9061,18 +9188,30 @@ addUsersInExitBlocks(VPlan &Plan,
90619188
for (const auto &[Idx, Op] : enumerate(ExitIRI->operands())) {
90629189
// Pass live-in values used by exit phis directly through to their users
90639190
// in the exit block.
9064-
if (Op->isLiveIn())
9191+
if (Op->isLiveIn() || Op->isNull())
90659192
continue;
90669193

90679194
// Currently only live-ins can be used by exit values from blocks not
90689195
// exiting via the vector latch through to the middle block.
9069-
if (ExitIRI->getParent()->getSinglePredecessor() != MiddleVPBB)
9070-
return false;
9071-
90729196
LLVMContext &Ctx = ExitIRI->getInstruction().getContext();
9073-
VPValue *Ext = B.createNaryOp(VPInstruction::ExtractFromEnd,
9074-
{Op, Plan.getOrAddLiveIn(ConstantInt::get(
9075-
IntegerType::get(Ctx, 32), 1))});
9197+
VPValue *Ext;
9198+
VPBasicBlock *PredVPBB =
9199+
cast<VPBasicBlock>(ExitIRI->getParent()->getPredecessors()[Idx]);
9200+
if (PredVPBB != MiddleVPBB) {
9201+
VPBasicBlock *VectorEarlyExitVPBB =
9202+
Plan.getVectorLoopRegion()->getEarlyExit();
9203+
VPBuilder B2(VectorEarlyExitVPBB,
9204+
VectorEarlyExitVPBB->getFirstNonPhi());
9205+
assert(ExitIRI->getParent()->getNumPredecessors() <= 2);
9206+
VPValue *EarlyExitMask =
9207+
Plan.getVectorLoopRegion()->getVectorEarlyExitCond();
9208+
Ext = B2.createNaryOp(VPInstruction::ExtractFirstActive,
9209+
{Op, EarlyExitMask});
9210+
} else {
9211+
Ext = B.createNaryOp(VPInstruction::ExtractFromEnd,
9212+
{Op, Plan.getOrAddLiveIn(ConstantInt::get(
9213+
IntegerType::get(Ctx, 32), 1))});
9214+
}
90769215
ExitIRI->setOperand(Idx, Ext);
90779216
}
90789217
}

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ Value *VPLane::getAsRuntimeExpr(IRBuilderBase &Builder,
8383
llvm_unreachable("Unknown lane kind");
8484
}
8585

86+
static VPValue NullValue;
87+
VPValue *VPValue::Null = &NullValue;
88+
8689
VPValue::VPValue(const unsigned char SC, Value *UV, VPDef *Def)
8790
: SubclassID(SC), UnderlyingVal(UV), Def(Def) {
8891
if (Def)

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,9 @@ class VPInstruction : public VPRecipeWithIRFlags,
12311231
// Returns a scalar boolean value, which is true if any lane of its single
12321232
// operand is true.
12331233
AnyOf,
1234+
// Extracts the first active lane of a vector, where the first operand is
1235+
// the predicate, and the second operand is the vector to extract.
1236+
ExtractFirstActive,
12341237
};
12351238

12361239
private:
@@ -3662,6 +3665,13 @@ class VPRegionBlock : public VPBlockBase {
36623665
/// VPRegionBlock.
36633666
VPBlockBase *Exiting;
36643667

3668+
/// Hold the Early Exit block of the SEME region, if one exists.
3669+
VPBasicBlock *EarlyExit;
3670+
3671+
/// If one exists, this keeps track of the vector early mask that triggered
3672+
/// the early exit.
3673+
VPValue *VectorEarlyExitCond;
3674+
36653675
/// An indicator whether this region is to generate multiple replicated
36663676
/// instances of output IR corresponding to its VPBlockBases.
36673677
bool IsReplicator;
@@ -3670,6 +3680,7 @@ class VPRegionBlock : public VPBlockBase {
36703680
VPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exiting,
36713681
const std::string &Name = "", bool IsReplicator = false)
36723682
: VPBlockBase(VPRegionBlockSC, Name), Entry(Entry), Exiting(Exiting),
3683+
EarlyExit(nullptr), VectorEarlyExitCond(nullptr),
36733684
IsReplicator(IsReplicator) {
36743685
assert(Entry->getPredecessors().empty() && "Entry block has predecessors.");
36753686
assert(Exiting->getSuccessors().empty() && "Exit block has successors.");
@@ -3678,6 +3689,7 @@ class VPRegionBlock : public VPBlockBase {
36783689
}
36793690
VPRegionBlock(const std::string &Name = "", bool IsReplicator = false)
36803691
: VPBlockBase(VPRegionBlockSC, Name), Entry(nullptr), Exiting(nullptr),
3692+
EarlyExit(nullptr), VectorEarlyExitCond(nullptr),
36813693
IsReplicator(IsReplicator) {}
36823694

36833695
~VPRegionBlock() override {
@@ -3717,6 +3729,22 @@ class VPRegionBlock : public VPBlockBase {
37173729
ExitingBlock->setParent(this);
37183730
}
37193731

3732+
/// Sets the early exit vector mask.
3733+
void setVectorEarlyExitCond(VPValue *V) {
3734+
assert(!VectorEarlyExitCond);
3735+
VectorEarlyExitCond = V;
3736+
}
3737+
3738+
/// Gets the early exit vector mask
3739+
VPValue *getVectorEarlyExitCond() const { return VectorEarlyExitCond; }
3740+
3741+
/// Set the vector early exit block
3742+
void setEarlyExit(VPBasicBlock *ExitBlock) { EarlyExit = ExitBlock; }
3743+
3744+
/// Get the vector early exit block
3745+
const VPBasicBlock *getEarlyExit() const { return EarlyExit; }
3746+
VPBasicBlock *getEarlyExit() { return EarlyExit; }
3747+
37203748
/// Returns the pre-header VPBasicBlock of the loop region.
37213749
VPBasicBlock *getPreheaderVPBB() {
37223750
assert(!isReplicator() && "should only get pre-header of loop regions");

0 commit comments

Comments
 (0)