From 9a64d8a957d207a143c8b6fd4321af3b21b41422 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Wed, 20 Nov 2024 16:55:52 +0800 Subject: [PATCH 1/4] [SCEV] Cache collected loop guards. NFCI This tries to compensate for it by caching the collected loop guards, which gives a -0.07% geomean reduction for stage2-O3: https://llvm-compile-time-tracker.com/compare.php?from=aff98e4be05a1060e489ce62a88ee0ff365e571a&to=198a76db2c0b8fbda5374ffd195731a9d47469e3&stat=instructions:u LoopAccessAnalysis already had a LoopGuards cache for the innermost loop, so this hoists it up into ScalarEvolution. --- .../llvm/Analysis/LoopAccessAnalysis.h | 3 --- llvm/include/llvm/Analysis/ScalarEvolution.h | 5 +++- llvm/lib/Analysis/LoopAccessAnalysis.cpp | 16 ++++--------- llvm/lib/Analysis/ScalarEvolution.cpp | 24 ++++++++++--------- 4 files changed, 22 insertions(+), 26 deletions(-) diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h index a35bc7402d1a8..872b68f924e65 100644 --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -334,9 +334,6 @@ class MemoryDepChecker { std::pair> PointerBounds; - /// Cache for the loop guards of InnermostLoop. - std::optional LoopGuards; - /// Check whether there is a plausible dependence between the two /// accesses. /// diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 885c5985f9d23..b7b9384c6e642 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1346,7 +1346,6 @@ class ScalarEvolution { /// Try to apply information from loop guards for \p L to \p Expr. const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L); - const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards); /// Return true if the loop has no abnormal exits. That is, if the loop /// is not infinite, it must exit through an explicit edge in the CFG. @@ -1651,6 +1650,10 @@ class ScalarEvolution { /// function as they are computed. DenseMap PredicatedBackedgeTakenCounts; + /// Cache the collected loop guards of the loops of this function as they are + /// computed. + DenseMap LoopGuardsCache; + /// Loops whose backedge taken counts directly use this non-constant SCEV. DenseMap, 4>> BECountUsers; diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 907bb7875dc80..9dfc3def3140a 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -1945,16 +1945,13 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize( !isa(SrcEnd_) && !isa(SinkStart_) && !isa(SinkEnd_)) { - if (!LoopGuards) - LoopGuards.emplace( - ScalarEvolution::LoopGuards::collect(InnermostLoop, SE)); - auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards); - auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards); + auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop); + auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop); if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart)) return MemoryDepChecker::Dependence::NoDep; - auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards); - auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards); + auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop); + auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop); if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart)) return MemoryDepChecker::Dependence::NoDep; } @@ -2057,10 +2054,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, return Dependence::NoDep; } } else { - if (!LoopGuards) - LoopGuards.emplace( - ScalarEvolution::LoopGuards::collect(InnermostLoop, SE)); - Dist = SE.applyLoopGuards(Dist, *LoopGuards); + Dist = SE.applyLoopGuards(Dist, InnermostLoop); } // Negative distances are not plausible dependencies. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 46b108606f6a6..70a45ef507e27 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8417,6 +8417,7 @@ void ScalarEvolution::forgetAllLoops() { // result. BackedgeTakenCounts.clear(); PredicatedBackedgeTakenCounts.clear(); + LoopGuardsCache.clear(); BECountUsers.clear(); LoopPropertiesCache.clear(); ConstantEvolutionLoopExitValue.clear(); @@ -10551,9 +10552,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, if (!isLoopInvariant(Step, L)) return getCouldNotCompute(); - LoopGuards Guards = LoopGuards::collect(L, *this); // Specialize step for this loop so we get context sensitive facts below. - const SCEV *StepWLG = applyLoopGuards(Step, Guards); + const SCEV *StepWLG = applyLoopGuards(Step, L); // For positive steps (counting up until unsigned overflow): // N = -Start/Step (as unsigned) @@ -10570,7 +10570,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // N = Distance (as unsigned) if (StepC && (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) { - APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards)); + APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L)); MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance)); // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated, @@ -10611,7 +10611,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); const SCEV *ConstantMax = getCouldNotCompute(); if (Exact != getCouldNotCompute()) { - APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards)); + APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L)); ConstantMax = getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact))); } @@ -10629,7 +10629,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, const SCEV *M = E; if (E != getCouldNotCompute()) { - APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards)); + APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L)); M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E))); } auto *S = isa(E) ? M : E; @@ -13674,6 +13674,7 @@ ScalarEvolution::~ScalarEvolution() { HasRecMap.clear(); BackedgeTakenCounts.clear(); PredicatedBackedgeTakenCounts.clear(); + LoopGuardsCache.clear(); assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); assert(PendingPhiRanges.empty() && "getRangeRef garbage"); @@ -15889,10 +15890,11 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { - return applyLoopGuards(Expr, LoopGuards::collect(L, *this)); -} - -const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, - const LoopGuards &Guards) { - return Guards.rewrite(Expr); + auto Itr = LoopGuardsCache.find(L); + if (Itr == LoopGuardsCache.end()) { + LoopGuards Guard = LoopGuards::collect(L, *this); + LoopGuardsCache.insert({L, Guard}); + return Guard.rewrite(Expr); + } + return Itr->second.rewrite(Expr); } From 7a0bf19b515a555f0b80c55e8d700dfcc8e88ff4 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Thu, 21 Nov 2024 14:33:06 +0800 Subject: [PATCH 2/4] Rework to memoize loop guards across multiple exits --- .../llvm/Analysis/LoopAccessAnalysis.h | 3 + llvm/include/llvm/Analysis/ScalarEvolution.h | 122 +++++++++--------- llvm/lib/Analysis/LoopAccessAnalysis.cpp | 16 ++- llvm/lib/Analysis/ScalarEvolution.cpp | 119 +++++++++-------- llvm/lib/Transforms/Scalar/IndVarSimplify.cpp | 9 +- 5 files changed, 150 insertions(+), 119 deletions(-) diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h index 872b68f924e65..a35bc7402d1a8 100644 --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -334,6 +334,9 @@ class MemoryDepChecker { std::pair> PointerBounds; + /// Cache for the loop guards of InnermostLoop. + std::optional LoopGuards; + /// Check whether there is a plausible dependence between the two /// accesses. /// diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index b7b9384c6e642..692dd02f4b432 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1112,6 +1112,46 @@ class ScalarEvolution { bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS); + class LoopGuards { + DenseMap RewriteMap; + bool PreserveNUW = false; + bool PreserveNSW = false; + ScalarEvolution &SE; + + LoopGuards(ScalarEvolution &SE) : SE(SE) {} + + /// Recursively collect loop guards in \p Guards, starting from + /// block \p Block with predecessor \p Pred. The intended starting point + /// is to collect from a loop header and its predecessor. + static void + collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, + const BasicBlock *Block, const BasicBlock *Pred, + SmallPtrSetImpl &VisitedBlocks, + unsigned Depth = 0); + + /// Collect loop guards in \p Guards, starting from PHINode \p + /// Phi, by calling \p collectFromBlock on the incoming blocks of + /// \Phi and trying to merge the found constraints into a single + /// combined one for \p Phi. + static void collectFromPHI( + ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, + const PHINode &Phi, SmallPtrSetImpl &VisitedBlocks, + SmallDenseMap &IncomingGuards, + unsigned Depth); + + public: + /// Collect rewrite map for loop guards for loop \p L, together with flags + /// indicating if NUW and NSW can be preserved during rewriting. + static LoopGuards collect(const Loop *L, ScalarEvolution &SE); + + /// Try to apply the collected loop guards to \p Expr. + const SCEV *rewrite(const SCEV *Expr) const; + }; + + /// Try to apply information from loop guards for \p L to \p Expr. + const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L); + const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards); + /// Information about the number of loop iterations for which a loop exit's /// branch condition evaluates to the not-taken path. This is a temporary /// pair of exact and max expressions that are eventually summarized in @@ -1167,6 +1207,7 @@ class ScalarEvolution { /// If \p AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, + std::function GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates = false); @@ -1308,45 +1349,6 @@ class ScalarEvolution { /// sharpen it. void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags); - class LoopGuards { - DenseMap RewriteMap; - bool PreserveNUW = false; - bool PreserveNSW = false; - ScalarEvolution &SE; - - LoopGuards(ScalarEvolution &SE) : SE(SE) {} - - /// Recursively collect loop guards in \p Guards, starting from - /// block \p Block with predecessor \p Pred. The intended starting point - /// is to collect from a loop header and its predecessor. - static void - collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, - const BasicBlock *Block, const BasicBlock *Pred, - SmallPtrSetImpl &VisitedBlocks, - unsigned Depth = 0); - - /// Collect loop guards in \p Guards, starting from PHINode \p - /// Phi, by calling \p collectFromBlock on the incoming blocks of - /// \Phi and trying to merge the found constraints into a single - /// combined one for \p Phi. - static void collectFromPHI( - ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, - const PHINode &Phi, SmallPtrSetImpl &VisitedBlocks, - SmallDenseMap &IncomingGuards, - unsigned Depth); - - public: - /// Collect rewrite map for loop guards for loop \p L, together with flags - /// indicating if NUW and NSW can be preserved during rewriting. - static LoopGuards collect(const Loop *L, ScalarEvolution &SE); - - /// Try to apply the collected loop guards to \p Expr. - const SCEV *rewrite(const SCEV *Expr) const; - }; - - /// Try to apply information from loop guards for \p L to \p Expr. - const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L); - /// Return true if the loop has no abnormal exits. That is, if the loop /// is not infinite, it must exit through an explicit edge in the CFG. /// (As opposed to either a) throwing out of the function or b) entering a @@ -1650,10 +1652,6 @@ class ScalarEvolution { /// function as they are computed. DenseMap PredicatedBackedgeTakenCounts; - /// Cache the collected loop guards of the loops of this function as they are - /// computed. - DenseMap LoopGuardsCache; - /// Loops whose backedge taken counts directly use this non-constant SCEV. DenseMap, 4>> BECountUsers; @@ -1843,6 +1841,7 @@ class ScalarEvolution { /// this call will try to use a minimal set of SCEV predicates in order to /// return an exact answer. ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, + std::function GetLoopGuards, bool IsOnlyExit, bool AllowPredicates = false); // Helper functions for computeExitLimitFromCond to avoid exponential time @@ -1875,17 +1874,17 @@ class ScalarEvolution { using ExitLimitCacheTy = ExitLimitCache; - ExitLimit computeExitLimitFromCondCached(ExitLimitCacheTy &Cache, - const Loop *L, Value *ExitCond, - bool ExitIfTrue, - bool ControlsOnlyExit, - bool AllowPredicates); - ExitLimit computeExitLimitFromCondImpl(ExitLimitCacheTy &Cache, const Loop *L, - Value *ExitCond, bool ExitIfTrue, - bool ControlsOnlyExit, - bool AllowPredicates); + ExitLimit computeExitLimitFromCondCached( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, + std::function GetLoopGuards, bool ExitIfTrue, + bool ControlsOnlyExit, bool AllowPredicates); + ExitLimit computeExitLimitFromCondImpl( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, + std::function GetLoopGuards, bool ExitIfTrue, + bool ControlsOnlyExit, bool AllowPredicates); std::optional computeExitLimitFromCondFromBinOp( - ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, + std::function GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates); /// Compute the number of times the backedge of the specified loop will @@ -1894,8 +1893,8 @@ class ScalarEvolution { /// to use a minimal set of SCEV predicates in order to return an exact /// answer. ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, - bool ExitIfTrue, - bool IsSubExpr, + std::function GetLoopGuards, + bool ExitIfTrue, bool IsSubExpr, bool AllowPredicates = false); /// Variant of previous which takes the components representing an ICmp @@ -1904,16 +1903,16 @@ class ScalarEvolution { /// has a materialized ICmp. ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + std::function GetLoopGuards, bool IsSubExpr, bool AllowPredicates = false); /// Compute the number of times the backedge of the specified loop will /// execute if its exit condition were a switch with a single exiting case /// to ExitingBB. - ExitLimit computeExitLimitFromSingleExitSwitch(const Loop *L, - SwitchInst *Switch, - BasicBlock *ExitingBB, - bool IsSubExpr); + ExitLimit computeExitLimitFromSingleExitSwitch( + const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB, + std::function GetLoopGuards, bool IsSubExpr); /// Compute the exit limit of a loop that is controlled by a /// "(IV >> 1) != 0" type comparison. We cannot compute the exact trip @@ -1937,8 +1936,9 @@ class ScalarEvolution { /// value to zero will execute. If not computable, return CouldNotCompute. /// If AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. - ExitLimit howFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr, - bool AllowPredicates = false); + ExitLimit howFarToZero(const SCEV *V, const Loop *L, + std::function GetLoopGuards, + bool IsSubExpr, bool AllowPredicates = false); /// Return the number of times an exit condition checking the specified /// value for nonzero will execute. If not computable, return diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 9dfc3def3140a..907bb7875dc80 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -1945,13 +1945,16 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize( !isa(SrcEnd_) && !isa(SinkStart_) && !isa(SinkEnd_)) { - auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop); - auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop); + if (!LoopGuards) + LoopGuards.emplace( + ScalarEvolution::LoopGuards::collect(InnermostLoop, SE)); + auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards); + auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards); if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart)) return MemoryDepChecker::Dependence::NoDep; - auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop); - auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop); + auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards); + auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards); if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart)) return MemoryDepChecker::Dependence::NoDep; } @@ -2054,7 +2057,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, return Dependence::NoDep; } } else { - Dist = SE.applyLoopGuards(Dist, InnermostLoop); + if (!LoopGuards) + LoopGuards.emplace( + ScalarEvolution::LoopGuards::collect(InnermostLoop, SE)); + Dist = SE.applyLoopGuards(Dist, *LoopGuards); } // Negative distances are not plausible dependencies. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 70a45ef507e27..0ff2c486a0661 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8417,7 +8417,6 @@ void ScalarEvolution::forgetAllLoops() { // result. BackedgeTakenCounts.clear(); PredicatedBackedgeTakenCounts.clear(); - LoopGuardsCache.clear(); BECountUsers.clear(); LoopPropertiesCache.clear(); ConstantEvolutionLoopExitValue.clear(); @@ -8807,6 +8806,12 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, const SCEV *MayExitMaxBECount = nullptr; bool MustExitMaxOrZero = false; bool IsOnlyExit = ExitingBlocks.size() == 1; + std::optional LoopGuards; + auto GetLoopGuards = [&LoopGuards, &L, this]() { + if (!LoopGuards) + LoopGuards.emplace(LoopGuards::collect(L, *this)); + return *LoopGuards; + }; // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts // and compute maxBECount. @@ -8822,7 +8827,8 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, continue; } - ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates); + ExitLimit EL = + computeExitLimit(L, ExitBB, GetLoopGuards, IsOnlyExit, AllowPredicates); assert((AllowPredicates || EL.Predicates.empty()) && "Predicated exit limit when predicates are not allowed!"); @@ -8897,6 +8903,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, + std::function GetLoopGuards, bool IsOnlyExit, bool AllowPredicates) { assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); // If our exiting block does not dominate the latch, then its connection with @@ -8912,9 +8919,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) && "It should have one successor in loop and one exit block!"); // Proceed to the next level to examine the exit condition expression. - return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue, - /*ControlsOnlyExit=*/IsOnlyExit, - AllowPredicates); + return computeExitLimitFromCond( + L, BI->getCondition(), GetLoopGuards, ExitIfTrue, + /*ControlsOnlyExit=*/IsOnlyExit, AllowPredicates); } if (SwitchInst *SI = dyn_cast(Term)) { @@ -8928,18 +8935,19 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, } assert(Exit && "Exiting block must have at least one exit"); return computeExitLimitFromSingleExitSwitch( - L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit); + L, SI, Exit, GetLoopGuards, /*ControlsOnlyExit=*/IsOnlyExit); } return getCouldNotCompute(); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond( - const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, - bool AllowPredicates) { + const Loop *L, Value *ExitCond, std::function GetLoopGuards, + bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates); - return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue, - ControlsOnlyExit, AllowPredicates); + return computeExitLimitFromCondCached(Cache, L, ExitCond, GetLoopGuards, + ExitIfTrue, ControlsOnlyExit, + AllowPredicates); } std::optional @@ -8975,37 +8983,41 @@ void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond, } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( - ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, + std::function GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) return *MaybeEL; - ExitLimit EL = computeExitLimitFromCondImpl( - Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates); + ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, GetLoopGuards, + ExitIfTrue, ControlsOnlyExit, + AllowPredicates); Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL); return EL; } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( - ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, + std::function GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { // Handle BinOp conditions (And, Or). if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( - Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) + Cache, L, ExitCond, GetLoopGuards, ExitIfTrue, ControlsOnlyExit, + AllowPredicates)) return *LimitFromBinOp; // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast(ExitCond)) { - ExitLimit EL = - computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit); + ExitLimit EL = computeExitLimitFromICmp(L, ExitCondICmp, GetLoopGuards, + ExitIfTrue, ControlsOnlyExit); if (EL.hasFullInfo() || !AllowPredicates) return EL; // Try again, but use SCEV predicates this time. - return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, + return computeExitLimitFromICmp(L, ExitCondICmp, GetLoopGuards, ExitIfTrue, ControlsOnlyExit, /*AllowPredicates=*/true); } @@ -9041,7 +9053,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( if (Offset != 0) LHS = getAddExpr(LHS, getConstant(Offset)); auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC), - ControlsOnlyExit, AllowPredicates); + GetLoopGuards, ControlsOnlyExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; } @@ -9052,7 +9065,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( std::optional ScalarEvolution::computeExitLimitFromCondFromBinOp( - ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, + std::function GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. Value *Op0, *Op1; @@ -9069,11 +9083,11 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( // br (or Op0 Op1), exit, loop bool EitherMayExit = IsAnd ^ ExitIfTrue; ExitLimit EL0 = computeExitLimitFromCondCached( - Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit, - AllowPredicates); + Cache, L, Op0, GetLoopGuards, ExitIfTrue, + ControlsOnlyExit && !EitherMayExit, AllowPredicates); ExitLimit EL1 = computeExitLimitFromCondCached( - Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit, - AllowPredicates); + Cache, L, Op1, GetLoopGuards, ExitIfTrue, + ControlsOnlyExit && !EitherMayExit, AllowPredicates); // Be robust against unsimplified IR for the form "op i1 X, NeutralElement" const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd); @@ -9132,8 +9146,9 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( - const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, - bool AllowPredicates) { + const Loop *L, ICmpInst *ExitCond, + std::function GetLoopGuards, bool ExitIfTrue, + bool ControlsOnlyExit, bool AllowPredicates) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Pred; if (!ExitIfTrue) @@ -9145,8 +9160,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); - ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, - AllowPredicates); + ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, GetLoopGuards, + ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; @@ -9161,7 +9176,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - bool ControlsOnlyExit, bool AllowPredicates) { + std::function GetLoopGuards, bool ControlsOnlyExit, + bool AllowPredicates) { // Try to evaluate any dependencies out of the loop. LHS = getSCEVAtScope(LHS, L); @@ -9249,8 +9265,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( if (isa(RHS)) return RHS; } - ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit, - AllowPredicates); + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, GetLoopGuards, + ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; break; @@ -9332,10 +9348,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( } ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, - SwitchInst *Switch, - BasicBlock *ExitingBlock, - bool ControlsOnlyExit) { +ScalarEvolution::computeExitLimitFromSingleExitSwitch( + const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, + std::function GetLoopGuards, bool ControlsOnlyExit) { assert(!L->contains(ExitingBlock) && "Not an exiting block!"); // Give up if the exit is the default dest of a switch. @@ -9348,7 +9363,8 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit); + ExitLimit EL = + howFarToZero(getMinusSCEV(LHS, RHS), L, GetLoopGuards, ControlsOnlyExit); if (EL.hasAnyInfo()) return EL; @@ -10486,10 +10502,10 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth); } -ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, - const Loop *L, - bool ControlsOnlyExit, - bool AllowPredicates) { +ScalarEvolution::ExitLimit +ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, + std::function GetLoopGuards, + bool ControlsOnlyExit, bool AllowPredicates) { // This is only used for loops with a "x != y" exit test. The exit condition // is now expressed as a single expression, V = x-y. So the exit test is @@ -10552,8 +10568,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, if (!isLoopInvariant(Step, L)) return getCouldNotCompute(); + LoopGuards Guards = GetLoopGuards(); // Specialize step for this loop so we get context sensitive facts below. - const SCEV *StepWLG = applyLoopGuards(Step, L); + const SCEV *StepWLG = applyLoopGuards(Step, Guards); // For positive steps (counting up until unsigned overflow): // N = -Start/Step (as unsigned) @@ -10570,7 +10587,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // N = Distance (as unsigned) if (StepC && (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) { - APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L)); + APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards)); MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance)); // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated, @@ -10611,7 +10628,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); const SCEV *ConstantMax = getCouldNotCompute(); if (Exact != getCouldNotCompute()) { - APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L)); + APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards)); ConstantMax = getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact))); } @@ -10629,7 +10646,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, const SCEV *M = E; if (E != getCouldNotCompute()) { - APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L)); + APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards)); M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E))); } auto *S = isa(E) ? M : E; @@ -13674,7 +13691,6 @@ ScalarEvolution::~ScalarEvolution() { HasRecMap.clear(); BackedgeTakenCounts.clear(); PredicatedBackedgeTakenCounts.clear(); - LoopGuardsCache.clear(); assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); assert(PendingPhiRanges.empty() && "getRangeRef garbage"); @@ -15890,11 +15906,10 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { - auto Itr = LoopGuardsCache.find(L); - if (Itr == LoopGuardsCache.end()) { - LoopGuards Guard = LoopGuards::collect(L, *this); - LoopGuardsCache.insert({L, Guard}); - return Guard.rewrite(Expr); - } - return Itr->second.rewrite(Expr); + return applyLoopGuards(Expr, LoopGuards::collect(L, *this)); +} + +const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, + const LoopGuards &Guards) { + return Guards.rewrite(Expr); } diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 8a3e0bc3eb971..62e6d541af5c6 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -1335,6 +1335,13 @@ static bool optimizeLoopExitWithUnknownExitCount( Visited.insert(OldCond); Worklist.push_back(OldCond); + std::optional LoopGuards; + auto GetLoopGuards = [&LoopGuards, &L, &SE]() { + if (!LoopGuards) + LoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE)); + return *LoopGuards; + }; + auto GoThrough = [&](Value *V) { Value *LHS = nullptr, *RHS = nullptr; if (Inverted) { @@ -1371,7 +1378,7 @@ static bool optimizeLoopExitWithUnknownExitCount( ScalarEvolution::ExitCountKind::SymbolicMaximum) == MaxIter) for (auto *ICmp : LeafConditions) { - auto EL = SE->computeExitLimitFromCond(L, ICmp, Inverted, + auto EL = SE->computeExitLimitFromCond(L, ICmp, GetLoopGuards, Inverted, /*ControlsExit*/ false); const SCEV *ExitMax = EL.SymbolicMaxNotTaken; if (isa(ExitMax)) From 9c8dac0196c3199842606a208e19235bc6505bc2 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Thu, 21 Nov 2024 18:05:58 +0800 Subject: [PATCH 3/4] Use function_ref, return const reference --- llvm/include/llvm/Analysis/ScalarEvolution.h | 40 ++++++++++--------- llvm/lib/Analysis/ScalarEvolution.cpp | 37 ++++++++--------- llvm/lib/Transforms/Scalar/IndVarSimplify.cpp | 11 ++--- 3 files changed, 46 insertions(+), 42 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 692dd02f4b432..c67cbefd7fb92 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1206,10 +1206,11 @@ class ScalarEvolution { /// /// If \p AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. - ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, - std::function GetLoopGuards, - bool ExitIfTrue, bool ControlsOnlyExit, - bool AllowPredicates = false); + ExitLimit + computeExitLimitFromCond(const Loop *L, Value *ExitCond, + function_ref GetLoopGuards, + bool ExitIfTrue, bool ControlsOnlyExit, + bool AllowPredicates = false); /// A predicate is said to be monotonically increasing if may go from being /// false to being true as the loop iterates, but never the other way @@ -1841,7 +1842,7 @@ class ScalarEvolution { /// this call will try to use a minimal set of SCEV predicates in order to /// return an exact answer. ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, - std::function GetLoopGuards, + function_ref GetLoopGuards, bool IsOnlyExit, bool AllowPredicates = false); // Helper functions for computeExitLimitFromCond to avoid exponential time @@ -1876,15 +1877,15 @@ class ScalarEvolution { ExitLimit computeExitLimitFromCondCached( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, - std::function GetLoopGuards, bool ExitIfTrue, + function_ref GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates); ExitLimit computeExitLimitFromCondImpl( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, - std::function GetLoopGuards, bool ExitIfTrue, + function_ref GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates); std::optional computeExitLimitFromCondFromBinOp( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, - std::function GetLoopGuards, bool ExitIfTrue, + function_ref GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates); /// Compute the number of times the backedge of the specified loop will @@ -1892,27 +1893,28 @@ class ScalarEvolution { /// ExitCond and ExitIfTrue. If AllowPredicates is set, this call will try /// to use a minimal set of SCEV predicates in order to return an exact /// answer. - ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, - std::function GetLoopGuards, - bool ExitIfTrue, bool IsSubExpr, - bool AllowPredicates = false); + ExitLimit + computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, + function_ref GetLoopGuards, + bool ExitIfTrue, bool IsSubExpr, + bool AllowPredicates = false); /// Variant of previous which takes the components representing an ICmp /// as opposed to the ICmpInst itself. Note that the prior version can /// return more precise results in some cases and is preferred when caller /// has a materialized ICmp. - ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - std::function GetLoopGuards, - bool IsSubExpr, - bool AllowPredicates = false); + ExitLimit + computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + function_ref GetLoopGuards, + bool IsSubExpr, bool AllowPredicates = false); /// Compute the number of times the backedge of the specified loop will /// execute if its exit condition were a switch with a single exiting case /// to ExitingBB. ExitLimit computeExitLimitFromSingleExitSwitch( const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB, - std::function GetLoopGuards, bool IsSubExpr); + function_ref GetLoopGuards, bool IsSubExpr); /// Compute the exit limit of a loop that is controlled by a /// "(IV >> 1) != 0" type comparison. We cannot compute the exact trip @@ -1937,7 +1939,7 @@ class ScalarEvolution { /// If AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. ExitLimit howFarToZero(const SCEV *V, const Loop *L, - std::function GetLoopGuards, + function_ref GetLoopGuards, bool IsSubExpr, bool AllowPredicates = false); /// Return the number of times an exit condition checking the specified diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 0ff2c486a0661..d4c4a12786872 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8806,11 +8806,11 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, const SCEV *MayExitMaxBECount = nullptr; bool MustExitMaxOrZero = false; bool IsOnlyExit = ExitingBlocks.size() == 1; - std::optional LoopGuards; - auto GetLoopGuards = [&LoopGuards, &L, this]() { - if (!LoopGuards) - LoopGuards.emplace(LoopGuards::collect(L, *this)); - return *LoopGuards; + std::optional CachedLoopGuards; + auto GetLoopGuards = [&CachedLoopGuards, &L, this]() -> const LoopGuards & { + if (!CachedLoopGuards) + CachedLoopGuards.emplace(LoopGuards::collect(L, *this)); + return *CachedLoopGuards; }; // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts @@ -8901,10 +8901,10 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, MaxBECount, MaxOrZero); } -ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, - std::function GetLoopGuards, - bool IsOnlyExit, bool AllowPredicates) { +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit( + const Loop *L, BasicBlock *ExitingBlock, + function_ref GetLoopGuards, bool IsOnlyExit, + bool AllowPredicates) { assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); // If our exiting block does not dominate the latch, then its connection with // loop's exit limit may be far from trivial. @@ -8942,8 +8942,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond( - const Loop *L, Value *ExitCond, std::function GetLoopGuards, - bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { + const Loop *L, Value *ExitCond, + function_ref GetLoopGuards, bool ExitIfTrue, + bool ControlsOnlyExit, bool AllowPredicates) { ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates); return computeExitLimitFromCondCached(Cache, L, ExitCond, GetLoopGuards, ExitIfTrue, ControlsOnlyExit, @@ -8984,7 +8985,7 @@ void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond, ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, - std::function GetLoopGuards, bool ExitIfTrue, + function_ref GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit, @@ -9000,7 +9001,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, - std::function GetLoopGuards, bool ExitIfTrue, + function_ref GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { // Handle BinOp conditions (And, Or). if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( @@ -9066,7 +9067,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( std::optional ScalarEvolution::computeExitLimitFromCondFromBinOp( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, - std::function GetLoopGuards, bool ExitIfTrue, + function_ref GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. Value *Op0, *Op1; @@ -9147,7 +9148,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const Loop *L, ICmpInst *ExitCond, - std::function GetLoopGuards, bool ExitIfTrue, + function_ref GetLoopGuards, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Pred; @@ -9176,7 +9177,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - std::function GetLoopGuards, bool ControlsOnlyExit, + function_ref GetLoopGuards, bool ControlsOnlyExit, bool AllowPredicates) { // Try to evaluate any dependencies out of the loop. @@ -9350,7 +9351,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromSingleExitSwitch( const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, - std::function GetLoopGuards, bool ControlsOnlyExit) { + function_ref GetLoopGuards, bool ControlsOnlyExit) { assert(!L->contains(ExitingBlock) && "Not an exiting block!"); // Give up if the exit is the default dest of a switch. @@ -10504,7 +10505,7 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, - std::function GetLoopGuards, + function_ref GetLoopGuards, bool ControlsOnlyExit, bool AllowPredicates) { // This is only used for loops with a "x != y" exit test. The exit condition diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 62e6d541af5c6..1cf5aca2266c6 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -1335,11 +1335,12 @@ static bool optimizeLoopExitWithUnknownExitCount( Visited.insert(OldCond); Worklist.push_back(OldCond); - std::optional LoopGuards; - auto GetLoopGuards = [&LoopGuards, &L, &SE]() { - if (!LoopGuards) - LoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE)); - return *LoopGuards; + std::optional CachedLoopGuards; + auto GetLoopGuards = [&CachedLoopGuards, &L, + &SE]() -> const ScalarEvolution::LoopGuards & { + if (!CachedLoopGuards) + CachedLoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE)); + return *CachedLoopGuards; }; auto GoThrough = [&](Value *V) { From a10aad8c6acfcb879f56ab9267f61d0b2c899939 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Thu, 21 Nov 2024 18:40:21 +0800 Subject: [PATCH 4/4] Use reference again and use in howManyLessThans --- llvm/include/llvm/Analysis/ScalarEvolution.h | 4 +++- llvm/lib/Analysis/ScalarEvolution.cpp | 21 ++++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index c67cbefd7fb92..756fd0eeeef74 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1960,7 +1960,9 @@ class ScalarEvolution { /// If \p AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. ExitLimit howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, - bool isSigned, bool ControlsOnlyExit, + bool isSigned, + function_ref GetLoopGuards, + bool ControlsOnlyExit, bool AllowPredicates = false); ExitLimit howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index d4c4a12786872..51ac06121c9c1 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9317,8 +9317,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = ICmpInst::isSigned(Pred); - ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, - AllowPredicates); + ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, GetLoopGuards, + ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; break; @@ -10569,7 +10569,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, if (!isLoopInvariant(Step, L)) return getCouldNotCompute(); - LoopGuards Guards = GetLoopGuards(); + const LoopGuards &Guards = GetLoopGuards(); // Specialize step for this loop so we get context sensitive facts below. const SCEV *StepWLG = applyLoopGuards(Step, Guards); @@ -12928,10 +12928,10 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, getConstant(StrideForMaxBECount) /* Step */); } -ScalarEvolution::ExitLimit -ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, - const Loop *L, bool IsSigned, - bool ControlsOnlyExit, bool AllowPredicates) { +ScalarEvolution::ExitLimit ScalarEvolution::howManyLessThans( + const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, + function_ref GetLoopGuards, bool ControlsOnlyExit, + bool AllowPredicates) { SmallVector Predicates; const SCEVAddRecExpr *IV = dyn_cast(LHS); @@ -12965,7 +12965,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this)); APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1); Limit = Limit.zext(OuterBitWidth); - return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit); + return getUnsignedRangeMax(applyLoopGuards(RHS, GetLoopGuards())) + .ule(Limit); }; auto Flags = AR->getNoWrapFlags(); if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW()) @@ -13216,8 +13217,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!BECount) { auto canProveRHSGreaterThanEqualStart = [&]() { auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; - const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); - const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); + const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, GetLoopGuards()); + const SCEV *GuardedStart = applyLoopGuards(OrigStart, GetLoopGuards()); if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) || isKnownPredicate(CondGE, GuardedRHS, GuardedStart))