Skip to content

[VPlan] Introduce multi-branch recipe, use for multi-exit loops (WIP). #109193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ class LoopVectorizationLegality {
/// we can use in-order reductions.
bool canVectorizeFPMath(bool EnableStrictReductions);

/// Return true if the loop can take the opt-in multi-condition vectorization
/// path: the option is enabled, the only exiting blocks are the header and
/// the latch, the header ends in a branch on a non-equality integer compare,
/// and no value defined in the loop is used outside of it.
bool canVectorizeMultiCond() const;

/// Return true if we can vectorize this loop while folding its tail by
/// masking.
bool canFoldTailByMasking() const;
Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ AllowStridedPointerIVs("lv-strided-pointer-ivs", cl::init(false), cl::Hidden,
cl::desc("Enable recognition of non-constant strided "
"pointer induction variables."));

// Opt-in switch for the multi-condition (two exiting blocks) vectorization
// path; hidden and disabled by default.
static cl::opt<bool> EnableMultiCond(
    "enable-multi-cond-vectorization", cl::init(false), cl::Hidden,
    cl::desc("Enable vectorization of loops whose exiting blocks are the "
             "loop header and the loop latch."));

namespace llvm {
cl::opt<bool>
HintsAllowReordering("hints-allow-reordering", cl::init(true), cl::Hidden,
Expand Down Expand Up @@ -1378,6 +1381,8 @@ bool LoopVectorizationLegality::isFixedOrderRecurrence(
}

/// Returns true if \p BB has to be predicated when the loop is vectorized.
///
/// When the multi-condition path applies, every block except the loop header
/// is reported as needing predication. In all other cases the answer comes
/// from LoopAccessInfo::blockNeedsPredication.
bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) const {
  if (canVectorizeMultiCond()) {
    // Only the header is exempt from predication on this path.
    const bool IsHeader = BB == TheLoop->getHeader();
    if (!IsHeader)
      return true;
  }
  return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
}

Expand Down Expand Up @@ -1514,6 +1519,35 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
return true;
}

bool LoopVectorizationLegality::canVectorizeMultiCond() const {
if (!EnableMultiCond)
return false;
SmallVector<BasicBlock *> Exiting;
TheLoop->getExitingBlocks(Exiting);
if (Exiting.size() != 2 || Exiting[0] != TheLoop->getHeader() ||
Exiting[1] != TheLoop->getLoopLatch() ||
any_of(*TheLoop->getHeader(), [](Instruction &I) {
return I.mayReadFromMemory() || I.mayHaveSideEffects();
}))
return false;
CmpInst::Predicate Pred;
Value *A, *B;
if (!match(
TheLoop->getHeader()->getTerminator(),
m_Br(m_ICmp(Pred, m_Value(A), m_Value(B)), m_Value(), m_Value())) ||
Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE)
return false;
if (any_of(TheLoop->getBlocks(), [this](BasicBlock *BB) {
return any_of(*BB, [this](Instruction &I) {
return any_of(I.users(), [this](User *U) {
return !TheLoop->contains(cast<Instruction>(U)->getParent());
});
});
}))
return false;
return true;
}

// Helper function to canVectorizeLoopNestCFG.
bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
bool UseVPlanNativePath) {
Expand Down
82 changes: 55 additions & 27 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1362,9 +1362,11 @@ class LoopVectorizationCostModel {
// If we might exit from anywhere but the latch, must run the exiting
// iteration in scalar form.
if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) {
LLVM_DEBUG(
dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
return true;
if (!Legal->canVectorizeMultiCond()) {
LLVM_DEBUG(
dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
return true;
}
}
if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue()) {
LLVM_DEBUG(dbgs() << "LV: Loop requires scalar epilogue: "
Expand Down Expand Up @@ -2535,8 +2537,17 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) {
LoopVectorPreHeader = OrigLoop->getLoopPreheader();
assert(LoopVectorPreHeader && "Invalid loop structure");
LoopExitBlock = OrigLoop->getUniqueExitBlock(); // may be nullptr
assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
"multiple exit loop without required epilogue?");
if (Legal->canVectorizeMultiCond()) {
BasicBlock *Latch = OrigLoop->getLoopLatch();
BasicBlock *TrueSucc =
cast<BranchInst>(Latch->getTerminator())->getSuccessor(0);
BasicBlock *FalseSucc =
cast<BranchInst>(Latch->getTerminator())->getSuccessor(1);
LoopExitBlock = OrigLoop->contains(TrueSucc) ? FalseSucc : TrueSucc;
} else {
assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
"multiple exit loop without required epilogue?");
}

LoopMiddleBlock =
SplitBlock(LoopVectorPreHeader, LoopVectorPreHeader->getTerminator(), DT,
Expand Down Expand Up @@ -2910,7 +2921,8 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
for (PHINode &PN : Exit->phis())
PSE.getSE()->forgetLcssaPhiWithNewPredecessor(OrigLoop, &PN);

if (Cost->requiresScalarEpilogue(VF.isVector())) {
if (Legal->canVectorizeMultiCond() ||
Cost->requiresScalarEpilogue(VF.isVector())) {
// No edge from the middle block to the unique exit block has been inserted
// and there is nothing to fix from vector loop; phis should have incoming
// from scalar loop only.
Expand Down Expand Up @@ -3554,7 +3566,8 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
TheLoop->getExitingBlocks(Exiting);
for (BasicBlock *E : Exiting) {
auto *Cmp = dyn_cast<Instruction>(E->getTerminator()->getOperand(0));
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse())
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse() &&
(TheLoop->getLoopLatch() == E || !Legal->canVectorizeMultiCond()))
AddToWorklistIfAllowed(Cmp);
}

Expand Down Expand Up @@ -7643,12 +7656,15 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
BestVPlan.execute(&State);

// 2.5 Collect reduction resume values.
auto *ExitVPBB =
cast<VPBasicBlock>(BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
for (VPRecipeBase &R : *ExitVPBB) {
createAndCollectMergePhiForReduction(
dyn_cast<VPInstruction>(&R), State, OrigLoop,
State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
VPBasicBlock *ExitVPBB = nullptr;
if (BestVPlan.getVectorLoopRegion()->getSingleSuccessor()) {
ExitVPBB = cast<VPBasicBlock>(
BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
for (VPRecipeBase &R : *ExitVPBB) {
createAndCollectMergePhiForReduction(
dyn_cast<VPInstruction>(&R), State, OrigLoop,
State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
}
}

// 2.6. Maintain Loop Hints
Expand All @@ -7674,6 +7690,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
LoopVectorizeHints Hints(L, true, *ORE);
Hints.setAlreadyVectorized();
}

TargetTransformInfo::UnrollingPreferences UP;
TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE);
if (!UP.UnrollVectorizedLoop || CanonicalIVStartValue)
Expand All @@ -7686,15 +7703,17 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
ILV.printDebugTracesAtEnd();

// 4. Adjust branch weight of the branch in the middle block.
auto *MiddleTerm =
cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
if (MiddleTerm->isConditional() &&
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
// Assume that `Count % VectorTripCount` is equally distributed.
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
assert(TripCount > 0 && "trip count should not be zero");
const uint32_t Weights[] = {1, TripCount - 1};
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
if (ExitVPBB) {
auto *MiddleTerm =
cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
if (MiddleTerm->isConditional() &&
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
// Assume that `Count % VectorTripCount` is equally distributed.
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
assert(TripCount > 0 && "trip count should not be zero");
const uint32_t Weights[] = {1, TripCount - 1};
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
}
}

return State.ExpandedSCEVs;
Expand Down Expand Up @@ -8079,7 +8098,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
// If source is an exiting block, we know the exit edge is dynamically dead
// in the vector loop, and thus we don't need to restrict the mask. Avoid
// adding uses of an otherwise potentially dead instruction.
if (OrigLoop->isLoopExiting(Src))
if (!Legal->canVectorizeMultiCond() && OrigLoop->isLoopExiting(Src))
return EdgeMaskCache[Edge] = SrcMask;

VPValue *EdgeMask = getVPValueOrAddLiveIn(BI->getCondition());
Expand Down Expand Up @@ -8729,6 +8748,8 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW,
static SetVector<VPIRInstruction *> collectUsersInExitBlock(
Loop *OrigLoop, VPRecipeBuilder &Builder, VPlan &Plan,
const MapVector<PHINode *, InductionDescriptor> &Inductions) {
if (!Plan.getVectorLoopRegion()->getSingleSuccessor())
return {};
auto *MiddleVPBB =
cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor());
// No edge from the middle block to the unique exit block has been inserted
Expand Down Expand Up @@ -8814,6 +8835,8 @@ static void addLiveOutsForFirstOrderRecurrences(
// TODO: Should be replaced by
// Plan->getScalarLoopRegion()->getSinglePredecessor() in the future once the
// scalar region is modeled as well.
if (!VectorRegion->getSingleSuccessor())
return;
auto *MiddleVPBB = cast<VPBasicBlock>(VectorRegion->getSingleSuccessor());
VPBasicBlock *ScalarPHVPBB = nullptr;
if (MiddleVPBB->getNumSuccessors() == 2) {
Expand Down Expand Up @@ -9100,6 +9123,11 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
"VPBasicBlock");
RecipeBuilder.fixHeaderPhis();

if (Legal->canVectorizeMultiCond()) {
VPlanTransforms::convertToMultiCond(*Plan, *PSE.getSE(), OrigLoop,
RecipeBuilder);
}

SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlock(
OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars());
addLiveOutsForFirstOrderRecurrences(*Plan, ExitUsersToFix);
Expand Down Expand Up @@ -9231,8 +9259,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
using namespace VPlanPatternMatch;
VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion();
VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock();
VPBasicBlock *MiddleVPBB =
cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
for (VPRecipeBase &R : Header->phis()) {
auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered()))
Expand All @@ -9251,8 +9277,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
for (VPUser *U : Cur->users()) {
auto *UserRecipe = cast<VPSingleDefRecipe>(U);
if (!UserRecipe->getParent()->getEnclosingLoopRegion()) {
assert(UserRecipe->getParent() == MiddleVPBB &&
"U must be either in the loop region or the middle block.");
continue;
}
Worklist.insert(UserRecipe);
Expand Down Expand Up @@ -9357,6 +9381,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
}
VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock();
Builder.setInsertPoint(&*LatchVPBB->begin());
if (!VectorLoopRegion->getSingleSuccessor())
return;
VPBasicBlock *MiddleVPBB =
cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi();
for (VPRecipeBase &R :
Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
Expand Down
48 changes: 37 additions & 11 deletions llvm/lib/Transforms/Vectorize/VPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ void VPIRBasicBlock::execute(VPTransformState *State) {
// backedges. A backward successor is set when the branch is created.
const auto &PredVPSuccessors = PredVPBB->getHierarchicalSuccessors();
unsigned idx = PredVPSuccessors.front() == this ? 0 : 1;
if (TermBr->getSuccessor(idx) &&
PredVPBlock == getPlan()->getVectorLoopRegion() &&
PredVPBlock->getNumSuccessors()) {
// Update PredBB and TermBr for BranchMultipleConds in predecessor.
PredBB = TermBr->getSuccessor(1);
TermBr = cast<BranchInst>(PredBB->getTerminator());
idx = 0;
}
assert(!TermBr->getSuccessor(idx) &&
"Trying to reset an existing successor block.");
TermBr->setSuccessor(idx, IRBB);
Expand Down Expand Up @@ -601,9 +609,11 @@ static bool hasConditionalTerminator(const VPBasicBlock *VPBB) {
}

const VPRecipeBase *R = &VPBB->back();
bool IsCondBranch = isa<VPBranchOnMaskRecipe>(R) ||
match(R, m_BranchOnCond(m_VPValue())) ||
match(R, m_BranchOnCount(m_VPValue(), m_VPValue()));
bool IsCondBranch =
isa<VPBranchOnMaskRecipe>(R) || match(R, m_BranchOnCond(m_VPValue())) ||
match(R, m_BranchOnCount(m_VPValue(), m_VPValue())) ||
(isa<VPInstruction>(R) && cast<VPInstruction>(R)->getOpcode() ==
VPInstruction::BranchMultipleConds);
(void)IsCondBranch;

if (VPBB->getNumSuccessors() >= 2 ||
Expand Down Expand Up @@ -908,8 +918,8 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);

VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
if (!RequiresScalarEpilogueCheck) {
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
return Plan;
}
Expand All @@ -923,10 +933,14 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
// we unconditionally branch to the scalar preheader. Do nothing.
// 3) Otherwise, construct a runtime check.
BasicBlock *IRExitBlock = TheLoop->getUniqueExitBlock();
auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
// The connection order corresponds to the operands of the conditional branch.
VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
if (IRExitBlock) {
auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
// The connection order corresponds to the operands of the conditional
// branch.
VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
}

auto *ScalarLatchTerm = TheLoop->getLoopLatch()->getTerminator();
// Here we use the same DebugLoc as the scalar loop latch terminator instead
Expand Down Expand Up @@ -1035,7 +1049,9 @@ void VPlan::execute(VPTransformState *State) {
// VPlan execution rather than earlier during VPlan construction.
BasicBlock *MiddleBB = State->CFG.ExitBB;
VPBasicBlock *MiddleVPBB =
cast<VPBasicBlock>(getVectorLoopRegion()->getSingleSuccessor());
getVectorLoopRegion()->getNumSuccessors() == 1
? cast<VPBasicBlock>(getVectorLoopRegion()->getSuccessors()[0])
: cast<VPBasicBlock>(getVectorLoopRegion()->getSuccessors()[1]);
// Find the VPBB for the scalar preheader, relying on the current structure
// when creating the middle block and its successors: if there's a single
// predecessor, it must be the scalar preheader. Otherwise, the second
Expand All @@ -1060,12 +1076,19 @@ void VPlan::execute(VPTransformState *State) {
State->CFG.DTU.applyUpdates({{DominatorTree::Delete, MiddleBB, ScalarPh}});

// Generate code in the loop pre-header and body.
for (VPBlockBase *Block : vp_depth_first_shallow(Entry))
ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT(
Entry);

for (VPBlockBase *Block : RPOT)
Block->execute(State);

VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock();
BasicBlock *VectorLatchBB = State->CFG.VPBB2IRBB[LatchVPBB];

if (!getVectorLoopRegion()->getSingleSuccessor())
VectorLatchBB =
cast<BranchInst>(VectorLatchBB->getTerminator())->getSuccessor(1);

// Fix the latch value of canonical, reduction and first-order recurrences
// phis in the vector loop.
VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock();
Expand All @@ -1092,7 +1115,10 @@ void VPlan::execute(VPTransformState *State) {
// Move the last step to the end of the latch block. This ensures
// consistent placement of all induction updates.
Instruction *Inc = cast<Instruction>(Phi->getIncomingValue(1));
Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode());
if (VectorLatchBB->getTerminator() == &*VectorLatchBB->getFirstNonPHI())
Inc->moveBefore(VectorLatchBB->getTerminator());
else
Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode());

// Use the steps for the last part as backedge value for the induction.
if (auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R))
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
CanonicalIVIncrementForPart,
BranchOnCount,
BranchOnCond,
BranchMultipleConds,
ComputeReductionResult,
// Takes the VPValue to extract from as first operand and the lane or part
// to extract as second operand, counting from the end starting with 1 for
Expand All @@ -1249,6 +1250,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
// operand). Only generates scalar values (either for the first lane only or
// for all lanes, depending on its uses).
PtrAdd,
AnyOf,
};

private:
Expand Down Expand Up @@ -1370,6 +1372,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
case Instruction::AtomicRMW:
case VPInstruction::BranchOnCond:
case VPInstruction::BranchOnCount:
case VPInstruction::BranchMultipleConds:
return false;
default:
return true;
Expand Down
Loading
Loading