Skip to content

Commit 2839aa9

Browse files
nikictru
authored andcommitted
[SimpleLoopUnswitch] Fix exponential unswitch
When unswitching via invariant condition injection, we currently mark the condition in the old loop, so that it does not get unswitched again. However, if there are multiple branches for which conditions can be injected, then we can do that for both the old and new loop. This means that the number of unswitches increases exponentially. Change the handling to be more similar to partial unswitching, where we instead mark the whole loop, rather than a single condition. This means that we will only generate a linear number of loops. TBH I think even that is still highly undesirable, and we should probably be unswitching all candidates at the same time, so that we end up with only two loops. But at least this mitigates the worst case. The test case is a reduced variant that generates 1700 lines of IR without this patch and 290 with it. Fixes llvm/llvm-project#66868. (cherry picked from commit 8362cae71b80bc43c8c680cdfb13c495705a622f)
1 parent 773f136 commit 2839aa9

File tree

3 files changed

+329
-47
lines changed

3 files changed

+329
-47
lines changed

llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2126,9 +2126,10 @@ static void unswitchNontrivialInvariants(
21262126
Loop &L, Instruction &TI, ArrayRef<Value *> Invariants,
21272127
IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI,
21282128
AssumptionCache &AC,
2129-
function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
2129+
function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
21302130
ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
2131-
function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze) {
2131+
function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze,
2132+
bool InjectedCondition) {
21322133
auto *ParentBB = TI.getParent();
21332134
BranchInst *BI = dyn_cast<BranchInst>(&TI);
21342135
SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
@@ -2581,7 +2582,7 @@ static void unswitchNontrivialInvariants(
25812582
for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
25822583
if (UpdatedL->getParentLoop() == ParentL)
25832584
SibLoops.push_back(UpdatedL);
2584-
UnswitchCB(IsStillLoop, PartiallyInvariant, SibLoops);
2585+
UnswitchCB(IsStillLoop, PartiallyInvariant, InjectedCondition, SibLoops);
25852586

25862587
if (MSSAU && VerifyMemorySSA)
25872588
MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -2979,13 +2980,6 @@ static bool shouldTryInjectInvariantCondition(
29792980
/// the metadata.
29802981
bool shouldTryInjectBasingOnMetadata(const BranchInst *BI,
29812982
const BasicBlock *TakenSucc) {
2982-
// Skip branches that have already been unswithed this way. After successful
2983-
// unswitching of injected condition, we will still have a copy of this loop
2984-
// which looks exactly the same as original one. To prevent the 2nd attempt
2985-
// of unswitching it in the same pass, mark this branch as "nothing to do
2986-
// here".
2987-
if (BI->hasMetadata("llvm.invariant.condition.injection.disabled"))
2988-
return false;
29892983
SmallVector<uint32_t> Weights;
29902984
if (!extractBranchWeights(*BI, Weights))
29912985
return false;
@@ -3068,13 +3062,9 @@ injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L,
30683062
Builder.CreateCondBr(InjectedCond, InLoopSucc, CheckBlock);
30693063

30703064
Builder.SetInsertPoint(CheckBlock);
3071-
auto *NewTerm = Builder.CreateCondBr(TI->getCondition(), TI->getSuccessor(0),
3072-
TI->getSuccessor(1));
3073-
3065+
Builder.CreateCondBr(TI->getCondition(), TI->getSuccessor(0),
3066+
TI->getSuccessor(1));
30743067
TI->eraseFromParent();
3075-
// Prevent infinite unswitching.
3076-
NewTerm->setMetadata("llvm.invariant.condition.injection.disabled",
3077-
MDNode::get(BB->getContext(), {}));
30783068

30793069
// Fixup phis.
30803070
for (auto &I : *InLoopSucc) {
@@ -3442,7 +3432,7 @@ static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT,
34423432
static bool unswitchBestCondition(
34433433
Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
34443434
AAResults &AA, TargetTransformInfo &TTI,
3445-
function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
3435+
function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
34463436
ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
34473437
function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
34483438
// Collect all invariant conditions within this loop (as opposed to an inner
@@ -3452,9 +3442,10 @@ static bool unswitchBestCondition(
34523442
Instruction *PartialIVCondBranch = nullptr;
34533443
collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo,
34543444
PartialIVCondBranch, L, LI, AA, MSSAU);
3455-
collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo,
3456-
PartialIVCondBranch, L, DT, LI, AA,
3457-
MSSAU);
3445+
if (!findOptionMDForLoop(&L, "llvm.loop.unswitch.injection.disable"))
3446+
collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo,
3447+
PartialIVCondBranch, L, DT, LI, AA,
3448+
MSSAU);
34583449
// If we didn't find any candidates, we're done.
34593450
if (UnswitchCandidates.empty())
34603451
return false;
@@ -3475,8 +3466,11 @@ static bool unswitchBestCondition(
34753466
return false;
34763467
}
34773468

3478-
if (Best.hasPendingInjection())
3469+
bool InjectedCondition = false;
3470+
if (Best.hasPendingInjection()) {
34793471
Best = injectPendingInvariantConditions(Best, L, DT, LI, AC, MSSAU);
3472+
InjectedCondition = true;
3473+
}
34803474
assert(!Best.hasPendingInjection() &&
34813475
"All injections should have been done by now!");
34823476

@@ -3504,7 +3498,7 @@ static bool unswitchBestCondition(
35043498
<< ") terminator: " << *Best.TI << "\n");
35053499
unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT,
35063500
LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB,
3507-
InsertFreeze);
3501+
InsertFreeze, InjectedCondition);
35083502
return true;
35093503
}
35103504

@@ -3533,7 +3527,7 @@ static bool
35333527
unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
35343528
AAResults &AA, TargetTransformInfo &TTI, bool Trivial,
35353529
bool NonTrivial,
3536-
function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
3530+
function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
35373531
ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
35383532
ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
35393533
function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
@@ -3548,7 +3542,8 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
35483542
if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) {
35493543
// If we unswitched successfully we will want to clean up the loop before
35503544
// processing it further so just mark it as unswitched and return.
3551-
UnswitchCB(/*CurrentLoopValid*/ true, false, {});
3545+
UnswitchCB(/*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false,
3546+
/*InjectedCondition*/ false, {});
35523547
return true;
35533548
}
35543549

@@ -3644,6 +3639,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
36443639

36453640
auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid,
36463641
bool PartiallyInvariant,
3642+
bool InjectedCondition,
36473643
ArrayRef<Loop *> NewLoops) {
36483644
// If we did a non-trivial unswitch, we have added new (cloned) loops.
36493645
if (!NewLoops.empty())
@@ -3663,6 +3659,16 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
36633659
Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
36643660
{DisableUnswitchMD});
36653661
L.setLoopID(NewLoopID);
3662+
} else if (InjectedCondition) {
3663+
// Do the same for injection of invariant conditions.
3664+
auto &Context = L.getHeader()->getContext();
3665+
MDNode *DisableUnswitchMD = MDNode::get(
3666+
Context,
3667+
MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
3668+
MDNode *NewLoopID = makePostTransformationMetadata(
3669+
Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
3670+
{DisableUnswitchMD});
3671+
L.setLoopID(NewLoopID);
36663672
} else
36673673
U.revisitCurrentLoop();
36683674
} else
@@ -3755,6 +3761,7 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
37553761
auto *SE = SEWP ? &SEWP->getSE() : nullptr;
37563762

37573763
auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, bool PartiallyInvariant,
3764+
bool InjectedCondition,
37583765
ArrayRef<Loop *> NewLoops) {
37593766
// If we did a non-trivial unswitch, we have added new (cloned) loops.
37603767
for (auto *NewL : NewLoops)
@@ -3765,9 +3772,9 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
37653772
// but it is the best we can do in the old PM.
37663773
if (CurrentLoopValid) {
37673774
// If the current loop has been unswitched using a partially invariant
3768-
// condition, we should not re-add the current loop to avoid unswitching
3769-
// on the same condition again.
3770-
if (!PartiallyInvariant)
3775+
// condition or injected invariant condition, we should not re-add the
3776+
// current loop to avoid unswitching on the same condition again.
3777+
if (!PartiallyInvariant && !InjectedCondition)
37713778
LPM.addLoop(*L);
37723779
} else
37733780
LPM.markLoopAsDeleted(*L);

0 commit comments

Comments
 (0)