@@ -563,6 +563,11 @@ class InnerLoopVectorizer {
563
563
Value *VectorTripCount, BasicBlock *MiddleBlock,
564
564
VPTransformState &State);
565
565
566
/// Fix up users of the induction \p OrigPhi (described by \p II), or of its
/// latch update, that are exit-block PHIs fed from the uncountable early
/// exiting block: each such PHI receives an additional incoming value from
/// \p VectorEarlyExitBB. The value is derived from the canonical IV and the
/// early-exit mask obtained from \p Plan and \p State.
+ void fixupEarlyExitIVUsers(PHINode *OrigPhi, const InductionDescriptor &II,
567
+ BasicBlock *VectorEarlyExitBB,
568
+ BasicBlock *MiddleBlock, VPlan &Plan,
569
+ VPTransformState &State);
570
+
566
571
/// Iteratively sink the scalarized operands of a predicated instruction into
567
572
/// the block that was created for it.
568
573
void sinkScalarOperands(Instruction *PredInst);
@@ -2838,6 +2843,23 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton(
2838
2843
return LoopVectorPreHeader;
2839
2844
}
2840
2845
2846
+ static bool isValueIncomingFromBlock(BasicBlock *ExitingBB, Value *V,
2847
+ Instruction *UI) {
2848
+ PHINode *PHI = dyn_cast<PHINode>(UI);
2849
+ assert(PHI && "Expected LCSSA form");
2850
+
2851
+ // If this loop has an uncountable early exit then there could be
2852
+ // different users of OrigPhi with either:
2853
+ // 1. Multiple users, because each exiting block (countable or
2854
+ // uncountable) jumps to the same exit block, or ..
2855
+ // 2. A single user with an incoming value from a countable or
2856
+ // uncountable exiting block.
2857
+ // In both cases there is no guarantee this came from a countable exiting
2858
+ // block, i.e. the latch.
2859
+ int Index = PHI->getBasicBlockIndex(ExitingBB);
2860
+ return Index != -1 && PHI->getIncomingValue(Index) == V;
2861
+ }
2862
+
2841
2863
// Fix up external users of the induction variable. At this point, we are
2842
2864
// in LCSSA form, with all external PHIs that use the IV having one input value,
2843
2865
// coming from the remainder loop. We need those PHIs to also have a correct
@@ -2853,19 +2875,20 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
2853
2875
// We allow both, but they, obviously, have different values.
2854
2876
2855
2877
DenseMap<Value *, Value *> MissingVals;
2878
+ BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch();
2856
2879
2857
2880
Value *EndValue = cast<PHINode>(OrigPhi->getIncomingValueForBlock(
2858
2881
OrigLoop->getLoopPreheader()))
2859
2882
->getIncomingValueForBlock(MiddleBlock);
2860
2883
2861
2884
// An external user of the last iteration's value should see the value that
2862
2885
// the remainder loop uses to initialize its own IV.
2863
- Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoop->getLoopLatch() );
2886
+ Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoopLatch );
2864
2887
for (User *U : PostInc->users()) {
2865
2888
Instruction *UI = cast<Instruction>(U);
2866
2889
if (!OrigLoop->contains(UI)) {
2867
- assert(isa<PHINode>(UI) && "Expected LCSSA form");
2868
- MissingVals[UI ] = EndValue;
2890
+ if (isValueIncomingFromBlock(OrigLoopLatch, PostInc, UI))
2891
+ MissingVals[cast<PHINode>(UI) ] = EndValue;
2869
2892
}
2870
2893
}
2871
2894
@@ -2875,7 +2898,9 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
2875
2898
for (User *U : OrigPhi->users()) {
2876
2899
auto *UI = cast<Instruction>(U);
2877
2900
if (!OrigLoop->contains(UI)) {
2878
- assert(isa<PHINode>(UI) && "Expected LCSSA form");
2901
+ if (!isValueIncomingFromBlock(OrigLoopLatch, OrigPhi, UI))
2902
+ continue;
2903
+
2879
2904
IRBuilder<> B(MiddleBlock->getTerminator());
2880
2905
2881
2906
// Fast-math-flags propagate from the original induction instruction.
@@ -2905,18 +2930,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
2905
2930
}
2906
2931
}
2907
2932
2908
- assert((MissingVals.empty() ||
2909
- all_of(MissingVals,
2910
- [MiddleBlock, this](const std::pair<Value *, Value *> &P) {
2911
- return all_of(
2912
- predecessors(cast<Instruction>(P.first)->getParent()),
2913
- [MiddleBlock, this](BasicBlock *Pred) {
2914
- return Pred == MiddleBlock ||
2915
- Pred == OrigLoop->getLoopLatch();
2916
- });
2917
- })) &&
2918
- "Expected escaping values from latch/middle.block only");
2919
-
2920
2933
for (auto &I : MissingVals) {
2921
2934
PHINode *PHI = cast<PHINode>(I.first);
2922
2935
// One corner case we have to handle is two IVs "chasing" each-other,
@@ -2929,6 +2942,102 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
2929
2942
}
2930
2943
}
2931
2944
2945
+ void InnerLoopVectorizer::fixupEarlyExitIVUsers(PHINode *OrigPhi,
2946
+ const InductionDescriptor &II,
2947
+ BasicBlock *VectorEarlyExitBB,
2948
+ BasicBlock *MiddleBlock,
2949
+ VPlan &Plan,
2950
+ VPTransformState &State) {
2951
+ // There are two kinds of external IV usages - those that use the value
2952
+ // computed in the last iteration (the PHI) and those that use the penultimate
2953
+ // value (the value that feeds into the phi from the loop latch).
2954
+ // We allow both, but they, obviously, have different values.
2955
+ DenseMap<Value *, Value *> MissingVals;
2956
+ BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch();
2957
+ BasicBlock *EarlyExitingBB = Legal->getUncountableEarlyExitingBlock();
2958
+ Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoopLatch);
2959
+
2960
+ // Obtain the canonical IV, since we have to use the most recent value
2961
+ // before exiting the loop early. This is unlike fixupIVUsers, which has
2962
+ // the luxury of using the end value in the middle block.
2963
+ VPBasicBlock *EntryVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
2964
+ // NOTE: We cannot call Plan.getCanonicalIV() here because the original
2965
+ // recipe created whilst building plans is no longer valid.
2966
+ VPHeaderPHIRecipe *CanonicalIVR =
2967
+ cast<VPHeaderPHIRecipe>(&*EntryVPBB->begin());
2968
+ Value *CanonicalIV = State.get(CanonicalIVR->getVPSingleValue(), true);
2969
+
2970
+ // Search for the mask that drove us to exit early.
2971
+ VPBasicBlock *EarlyExitVPBB = Plan.getVectorLoopRegion()->getEarlyExit();
2972
+ VPBasicBlock *MiddleSplitVPBB =
2973
+ cast<VPBasicBlock>(EarlyExitVPBB->getSinglePredecessor());
2974
+ VPInstruction *BranchOnCond =
2975
+ cast<VPInstruction>(MiddleSplitVPBB->getTerminator());
2976
+ assert(BranchOnCond->getOpcode() == VPInstruction::BranchOnCond &&
2977
+ "Expected middle.split block terminator to be a branch-on-cond");
2978
+ VPInstruction *ScalarEarlyExitCond =
2979
+ cast<VPInstruction>(BranchOnCond->getOperand(0));
2980
+ assert(
2981
+ ScalarEarlyExitCond->getOpcode() == VPInstruction::AnyOf &&
2982
+ "Expected middle.split block terminator branch condition to be any-of");
2983
+ VPValue *VectorEarlyExitCond = ScalarEarlyExitCond->getOperand(0);
2984
+ // Finally get the mask that led us into the early exit block.
2985
+ Value *EarlyExitMask = State.get(VectorEarlyExitCond);
2986
+
2987
+ // Calculate the IV step.
2988
+ VPValue *StepVPV = Plan.getSCEVExpansion(II.getStep());
2989
+ assert(StepVPV && "step must have been expanded during VPlan execution");
2990
+ Value *Step = StepVPV->isLiveIn() ? StepVPV->getLiveInIRValue()
2991
+ : State.get(StepVPV, VPLane(0));
2992
+
2993
+ auto FixUpPhi = [&](Instruction *UI, bool PostInc) -> Value * {
2994
+ IRBuilder<> B(VectorEarlyExitBB->getTerminator());
2995
+ assert(isa<PHINode>(UI) && "Expected LCSSA form");
2996
+
2997
+ // Fast-math-flags propagate from the original induction instruction.
2998
+ if (isa_and_nonnull<FPMathOperator>(II.getInductionBinOp()))
2999
+ B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags());
3000
+
3001
+ Type *CtzType = CanonicalIV->getType();
3002
+ Value *Ctz = B.CreateCountTrailingZeroElems(CtzType, EarlyExitMask);
3003
+ Ctz = B.CreateAdd(Ctz, cast<PHINode>(CanonicalIV));
3004
+ if (PostInc)
3005
+ Ctz = B.CreateAdd(Ctz, ConstantInt::get(CtzType, 1));
3006
+
3007
+ Value *Escape = emitTransformedIndex(B, Ctz, II.getStartValue(), Step,
3008
+ II.getKind(), II.getInductionBinOp());
3009
+ Escape->setName("ind.early.escape");
3010
+ return Escape;
3011
+ };
3012
+
3013
+ for (User *U : PostInc->users()) {
3014
+ auto *UI = cast<Instruction>(U);
3015
+ if (!OrigLoop->contains(UI)) {
3016
+ if (isValueIncomingFromBlock(EarlyExitingBB, PostInc, UI))
3017
+ MissingVals[UI] = FixUpPhi(UI, true);
3018
+ }
3019
+ }
3020
+
3021
+ for (User *U : OrigPhi->users()) {
3022
+ auto *UI = cast<Instruction>(U);
3023
+ if (!OrigLoop->contains(UI)) {
3024
+ if (isValueIncomingFromBlock(EarlyExitingBB, OrigPhi, UI))
3025
+ MissingVals[UI] = FixUpPhi(UI, false);
3026
+ }
3027
+ }
3028
+
3029
+ for (auto &I : MissingVals) {
3030
+ PHINode *PHI = cast<PHINode>(I.first);
3031
+ // One corner case we have to handle is two IVs "chasing" each-other,
3032
+ // that is %IV2 = phi [...], [ %IV1, %latch ]
3033
+ // In this case, if IV1 has an external use, we need to avoid adding both
3034
+ // "last value of IV1" and "penultimate value of IV2". So, verify that we
3035
+ // don't already have an incoming value for the middle block.
3036
+ if (PHI->getBasicBlockIndex(VectorEarlyExitBB) == -1)
3037
+ PHI->addIncoming(I.second, VectorEarlyExitBB);
3038
+ }
3039
+ }
3040
+
2932
3041
namespace {
2933
3042
2934
3043
struct CSEDenseMapInfo {
@@ -3062,6 +3171,13 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
3062
3171
OuterLoop->addBasicBlockToLoop(MiddleSplitBB, *LI);
3063
3172
PredVPBB = PredVPBB->getSinglePredecessor();
3064
3173
}
3174
+
3175
+ BasicBlock *OrigEarlyExitBB = Legal->getUncountableEarlyExitBlock();
3176
+ if (Loop *EEL = LI->getLoopFor(OrigEarlyExitBB)) {
3177
+ BasicBlock *VectorEarlyExitBB =
3178
+ State.CFG.VPBB2IRBB[VectorRegion->getEarlyExit()];
3179
+ EEL->addBasicBlockToLoop(VectorEarlyExitBB, *LI);
3180
+ }
3065
3181
}
3066
3182
3067
3183
// After vectorization, the exit blocks of the original loop will have
@@ -3091,6 +3207,15 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
3091
3207
getOrCreateVectorTripCount(nullptr), LoopMiddleBlock, State);
3092
3208
}
3093
3209
3210
+ if (Legal->hasUncountableEarlyExit()) {
3211
+ VPBasicBlock *VectorEarlyExitVPBB =
3212
+ cast<VPBasicBlock>(VectorRegion->getEarlyExit());
3213
+ BasicBlock *VectorEarlyExitBB = State.CFG.VPBB2IRBB[VectorEarlyExitVPBB];
3214
+ for (const auto &Entry : Legal->getInductionVars())
3215
+ fixupEarlyExitIVUsers(Entry.first, Entry.second, VectorEarlyExitBB,
3216
+ LoopMiddleBlock, Plan, State);
3217
+ }
3218
+
3094
3219
for (Instruction *PI : PredicatedInstructions)
3095
3220
sinkScalarOperands(&*PI);
3096
3221
@@ -8974,6 +9099,9 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) {
8974
9099
auto *VectorPhiR = cast<VPHeaderPHIRecipe>(Builder.getRecipe(ScalarPhiI));
8975
9100
if (!isa<VPFirstOrderRecurrencePHIRecipe, VPReductionPHIRecipe>(VectorPhiR))
8976
9101
continue;
9102
+ assert(!Plan.getVectorLoopRegion()->getEarlyExit() &&
9103
+ "Cannot handle "
9104
+ "first-order recurrences with uncountable early exits");
8977
9105
// The backedge value provides the value to resume coming out of a loop,
8978
9106
// which for FORs is a vector whose last element needs to be extracted. The
8979
9107
// start value provides the value if the loop is bypassed.
@@ -9032,8 +9160,7 @@ static SetVector<VPIRInstruction *> collectUsersInExitBlocks(
9032
9160
auto *P = dyn_cast<PHINode>(U);
9033
9161
return P && Inductions.contains(P);
9034
9162
}))) {
9035
- if (ExitVPBB->getSinglePredecessor() == MiddleVPBB)
9036
- continue;
9163
+ V = VPValue::getNull();
9037
9164
}
9038
9165
ExitUsersToFix.insert(ExitIRI);
9039
9166
ExitIRI->addOperand(V);
@@ -9061,18 +9188,30 @@ addUsersInExitBlocks(VPlan &Plan,
9061
9188
for (const auto &[Idx, Op] : enumerate(ExitIRI->operands())) {
9062
9189
// Pass live-in values used by exit phis directly through to their users
9063
9190
// in the exit block.
9064
- if (Op->isLiveIn())
9191
+ if (Op->isLiveIn() || Op->isNull() )
9065
9192
continue;
9066
9193
9067
9194
// Currently only live-ins can be used by exit values from blocks not
9068
9195
// exiting via the vector latch through to the middle block.
9069
- if (ExitIRI->getParent()->getSinglePredecessor() != MiddleVPBB)
9070
- return false;
9071
-
9072
9196
LLVMContext &Ctx = ExitIRI->getInstruction().getContext();
9073
- VPValue *Ext = B.createNaryOp(VPInstruction::ExtractFromEnd,
9074
- {Op, Plan.getOrAddLiveIn(ConstantInt::get(
9075
- IntegerType::get(Ctx, 32), 1))});
9197
+ VPValue *Ext;
9198
+ VPBasicBlock *PredVPBB =
9199
+ cast<VPBasicBlock>(ExitIRI->getParent()->getPredecessors()[Idx]);
9200
+ if (PredVPBB != MiddleVPBB) {
9201
+ VPBasicBlock *VectorEarlyExitVPBB =
9202
+ Plan.getVectorLoopRegion()->getEarlyExit();
9203
+ VPBuilder B2(VectorEarlyExitVPBB,
9204
+ VectorEarlyExitVPBB->getFirstNonPhi());
9205
+ assert(ExitIRI->getParent()->getNumPredecessors() <= 2);
9206
+ VPValue *EarlyExitMask =
9207
+ Plan.getVectorLoopRegion()->getVectorEarlyExitCond();
9208
+ Ext = B2.createNaryOp(VPInstruction::ExtractFirstActive,
9209
+ {Op, EarlyExitMask});
9210
+ } else {
9211
+ Ext = B.createNaryOp(VPInstruction::ExtractFromEnd,
9212
+ {Op, Plan.getOrAddLiveIn(ConstantInt::get(
9213
+ IntegerType::get(Ctx, 32), 1))});
9214
+ }
9076
9215
ExitIRI->setOperand(Idx, Ext);
9077
9216
}
9078
9217
}
0 commit comments