Skip to content

Commit df3d70b

Browse files
authored
[Analysis] Add getPredicatedExitCount to ScalarEvolution (#105649)
Following a reviewer request on PR #88385, this patch adds a getPredicatedExitCount function, which is similar to getExitCount except that it uses the predicated backedge-taken information. With PR #88385 we will start to care about more loops with multiple exits, and will want the ability to query exit counts for a particular exiting block. Such loops may require predicates in order to be vectorised. New tests are added in: Analysis/ScalarEvolution/predicated-exit-count.ll
1 parent ef26afc commit df3d70b

File tree

5 files changed

+200
-33
lines changed

5 files changed

+200
-33
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,13 @@ class ScalarEvolution {
871871
const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
872872
ExitCountKind Kind = Exact);
873873

874+
/// Same as above except this uses the predicated backedge taken info and
875+
/// may require predicates.
876+
const SCEV *
877+
getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
878+
SmallVectorImpl<const SCEVPredicate *> *Predicates,
879+
ExitCountKind Kind = Exact);
880+
874881
/// If the specified loop has a predictable backedge-taken count, return it,
875882
/// otherwise return a SCEVCouldNotCompute object. The backedge-taken count is
876883
/// the number of times the loop header will be branched to from within the
@@ -1517,6 +1524,10 @@ class ScalarEvolution {
15171524
bool isComplete() const { return IsComplete; }
15181525
const SCEV *getConstantMax() const { return ConstantMax; }
15191526

1527+
const ExitNotTakenInfo *getExitNotTaken(
1528+
const BasicBlock *ExitingBlock,
1529+
SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
1530+
15201531
public:
15211532
BackedgeTakenInfo() = default;
15221533
BackedgeTakenInfo(BackedgeTakenInfo &&) = default;
@@ -1563,25 +1574,44 @@ class ScalarEvolution {
15631574
/// Return the number of times this loop exit may fall through to the back
15641575
/// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via
15651576
/// this block before this number of iterations, but may exit via another
1566-
/// block.
1567-
const SCEV *getExact(const BasicBlock *ExitingBlock,
1568-
ScalarEvolution *SE) const;
1577+
/// block. If \p Predicates is null the function returns CouldNotCompute if
1578+
/// predicates are required, otherwise it fills in the required predicates.
1579+
const SCEV *getExact(
1580+
const BasicBlock *ExitingBlock, ScalarEvolution *SE,
1581+
SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
1582+
if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
1583+
return ENT->ExactNotTaken;
1584+
else
1585+
return SE->getCouldNotCompute();
1586+
}
15691587

15701588
/// Get the constant max backedge taken count for the loop.
15711589
const SCEV *getConstantMax(ScalarEvolution *SE) const;
15721590

15731591
/// Get the constant max backedge taken count for the particular loop exit.
1574-
const SCEV *getConstantMax(const BasicBlock *ExitingBlock,
1575-
ScalarEvolution *SE) const;
1592+
const SCEV *getConstantMax(
1593+
const BasicBlock *ExitingBlock, ScalarEvolution *SE,
1594+
SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
1595+
if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
1596+
return ENT->ConstantMaxNotTaken;
1597+
else
1598+
return SE->getCouldNotCompute();
1599+
}
15761600

15771601
/// Get the symbolic max backedge taken count for the loop.
15781602
const SCEV *getSymbolicMax(
15791603
const Loop *L, ScalarEvolution *SE,
15801604
SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr);
15811605

15821606
/// Get the symbolic max backedge taken count for the particular loop exit.
1583-
const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
1584-
ScalarEvolution *SE) const;
1607+
const SCEV *getSymbolicMax(
1608+
const BasicBlock *ExitingBlock, ScalarEvolution *SE,
1609+
SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
1610+
if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
1611+
return ENT->SymbolicMaxNotTaken;
1612+
else
1613+
return SE->getCouldNotCompute();
1614+
}
15851615

15861616
/// Return true if the number of times this backedge is taken is either the
15871617
/// value returned by getConstantMax or zero.

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8247,6 +8247,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
82478247
llvm_unreachable("Invalid ExitCountKind!");
82488248
}
82498249

8250+
const SCEV *ScalarEvolution::getPredicatedExitCount(
8251+
const Loop *L, const BasicBlock *ExitingBlock,
8252+
SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8253+
switch (Kind) {
8254+
case Exact:
8255+
return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8256+
Predicates);
8257+
case SymbolicMaximum:
8258+
return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8259+
Predicates);
8260+
case ConstantMaximum:
8261+
return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8262+
Predicates);
8263+
};
8264+
llvm_unreachable("Invalid ExitCountKind!");
8265+
}
8266+
82508267
const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
82518268
const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
82528269
return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
@@ -8574,33 +8591,22 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
85748591
return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
85758592
}
85768593

8577-
/// Get the exact not taken count for this loop exit.
8578-
const SCEV *
8579-
ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8580-
ScalarEvolution *SE) const {
8581-
for (const auto &ENT : ExitNotTaken)
8582-
if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8583-
return ENT.ExactNotTaken;
8584-
8585-
return SE->getCouldNotCompute();
8586-
}
8587-
8588-
const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8589-
const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8594+
const ScalarEvolution::ExitNotTakenInfo *
8595+
ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8596+
const BasicBlock *ExitingBlock,
8597+
SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
85908598
for (const auto &ENT : ExitNotTaken)
8591-
if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8592-
return ENT.ConstantMaxNotTaken;
8593-
8594-
return SE->getCouldNotCompute();
8595-
}
8596-
8597-
const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8598-
const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8599-
for (const auto &ENT : ExitNotTaken)
8600-
if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8601-
return ENT.SymbolicMaxNotTaken;
8599+
if (ENT.ExitingBlock == ExitingBlock) {
8600+
if (ENT.hasAlwaysTruePredicate())
8601+
return &ENT;
8602+
else if (Predicates) {
8603+
for (const auto *P : ENT.Predicates)
8604+
Predicates->push_back(P);
8605+
return &ENT;
8606+
}
8607+
}
86028608

8603-
return SE->getCouldNotCompute();
8609+
return nullptr;
86048610
}
86058611

86068612
/// getConstantMax - Get the constant max backedge taken count for the loop.
@@ -13642,7 +13648,21 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
1364213648
if (ExitingBlocks.size() > 1)
1364313649
for (BasicBlock *ExitingBlock : ExitingBlocks) {
1364413650
OS << " exit count for " << ExitingBlock->getName() << ": ";
13645-
PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
13651+
const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13652+
PrintSCEVWithTypeHint(OS, EC);
13653+
if (isa<SCEVCouldNotCompute>(EC)) {
13654+
// Retry with predicates.
13655+
SmallVector<const SCEVPredicate *, 4> Predicates;
13656+
EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13657+
if (!isa<SCEVCouldNotCompute>(EC)) {
13658+
OS << "\n predicated exit count for " << ExitingBlock->getName()
13659+
<< ": ";
13660+
PrintSCEVWithTypeHint(OS, EC);
13661+
OS << "\n Predicates:\n";
13662+
for (const auto *P : Predicates)
13663+
P->print(OS, 4);
13664+
}
13665+
}
1364613666
OS << "\n";
1364713667
}
1364813668

@@ -13682,6 +13702,20 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
1368213702
auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
1368313703
ScalarEvolution::SymbolicMaximum);
1368413704
PrintSCEVWithTypeHint(OS, ExitBTC);
13705+
if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13706+
// Retry with predicates.
13707+
SmallVector<const SCEVPredicate *, 4> Predicates;
13708+
ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13709+
ScalarEvolution::SymbolicMaximum);
13710+
if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13711+
OS << "\n predicated symbolic max exit count for "
13712+
<< ExitingBlock->getName() << ": ";
13713+
PrintSCEVWithTypeHint(OS, ExitBTC);
13714+
OS << "\n Predicates:\n";
13715+
for (const auto *P : Predicates)
13716+
P->print(OS, 4);
13717+
}
13718+
}
1368513719
OS << "\n";
1368613720
}
1368713721

llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,25 @@ define void @ule_from_zero_no_nuw(i32 %M, i32 %N) {
9393
; CHECK-NEXT: Determining loop execution counts for: @ule_from_zero_no_nuw
9494
; CHECK-NEXT: Loop %loop: <multiple exits> Unpredictable backedge-taken count.
9595
; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE***
96+
; CHECK-NEXT: predicated exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
97+
; CHECK-NEXT: Predicates:
98+
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
99+
; CHECK-EMPTY:
96100
; CHECK-NEXT: exit count for latch: %N
97101
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 -1
98102
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is %N
99103
; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE***
104+
; CHECK-NEXT: predicated symbolic max exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
105+
; CHECK-NEXT: Predicates:
106+
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
107+
; CHECK-EMPTY:
100108
; CHECK-NEXT: symbolic max exit count for latch: %N
101109
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
102110
; CHECK-NEXT: Predicates:
103111
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
112+
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
113+
; CHECK-NEXT: Predicates:
114+
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
104115
;
105116
entry:
106117
br label %loop
@@ -211,14 +222,25 @@ define void @sle_from_int_min_no_nsw(i32 %M, i32 %N) {
211222
; CHECK-NEXT: Determining loop execution counts for: @sle_from_int_min_no_nsw
212223
; CHECK-NEXT: Loop %loop: <multiple exits> Unpredictable backedge-taken count.
213224
; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE***
225+
; CHECK-NEXT: predicated exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
226+
; CHECK-NEXT: Predicates:
227+
; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
228+
; CHECK-EMPTY:
214229
; CHECK-NEXT: exit count for latch: (-2147483648 + %N)
215230
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 -1
216231
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (-2147483648 + %N)
217232
; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE***
233+
; CHECK-NEXT: predicated symbolic max exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
234+
; CHECK-NEXT: Predicates:
235+
; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
236+
; CHECK-EMPTY:
218237
; CHECK-NEXT: symbolic max exit count for latch: (-2147483648 + %N)
219238
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
220239
; CHECK-NEXT: Predicates:
221240
; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
241+
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
242+
; CHECK-NEXT: Predicates:
243+
; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
222244
;
223245
entry:
224246
br label %loop
llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt -disable-output "-passes=print<scalar-evolution>" -scalar-evolution-classify-expressions=0 < %s 2>&1 | FileCheck %s
3+
4+
5+
define i32 @multiple_exits_with_predicates(ptr %src1, ptr readonly %src2, i32 %end) {
6+
; CHECK-LABEL: 'multiple_exits_with_predicates'
7+
; CHECK-NEXT: Determining loop execution counts for: @multiple_exits_with_predicates
8+
; CHECK-NEXT: Loop %for.body: <multiple exits> Unpredictable backedge-taken count.
9+
; CHECK-NEXT: exit count for for.body: ***COULDNOTCOMPUTE***
10+
; CHECK-NEXT: predicated exit count for for.body: i32 1023
11+
; CHECK-NEXT: Predicates:
12+
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
13+
; CHECK-EMPTY:
14+
; CHECK-NEXT: exit count for for.work: ***COULDNOTCOMPUTE***
15+
; CHECK-NEXT: exit count for for.inc: ***COULDNOTCOMPUTE***
16+
; CHECK-NEXT: predicated exit count for for.inc: (-1 + (1 umax %end))
17+
; CHECK-NEXT: Predicates:
18+
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
19+
; CHECK-EMPTY:
20+
; CHECK-NEXT: Loop %for.body: Unpredictable constant max backedge-taken count.
21+
; CHECK-NEXT: Loop %for.body: Unpredictable symbolic max backedge-taken count.
22+
; CHECK-NEXT: symbolic max exit count for for.body: ***COULDNOTCOMPUTE***
23+
; CHECK-NEXT: predicated symbolic max exit count for for.body: i32 1023
24+
; CHECK-NEXT: Predicates:
25+
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
26+
; CHECK-EMPTY:
27+
; CHECK-NEXT: symbolic max exit count for for.work: ***COULDNOTCOMPUTE***
28+
; CHECK-NEXT: symbolic max exit count for for.inc: ***COULDNOTCOMPUTE***
29+
; CHECK-NEXT: predicated symbolic max exit count for for.inc: (-1 + (1 umax %end))
30+
; CHECK-NEXT: Predicates:
31+
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
32+
; CHECK-EMPTY:
33+
; CHECK-NEXT: Loop %for.body: Predicated symbolic max backedge-taken count is (1023 umin (-1 + (1 umax %end)))
34+
; CHECK-NEXT: Predicates:
35+
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
36+
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
37+
;
38+
entry:
39+
br label %for.body
40+
41+
for.body:
42+
%index = phi i8 [ %index.next, %for.inc ], [ 0, %entry ]
43+
%index.next = add i8 %index, 1
44+
%conv = zext i8 %index.next to i32
45+
%cmp.body = icmp ne i32 %conv, 1024
46+
br i1 %cmp.body, label %for.work, label %exit
47+
48+
for.work:
49+
%arrayidx = getelementptr inbounds i32, ptr %src1, i8 %index
50+
%0 = load i32, ptr %arrayidx, align 4
51+
%arrayidx3 = getelementptr inbounds i32, ptr %src2, i8 %index
52+
%1 = load i32, ptr %arrayidx3, align 4
53+
%cmp.work = icmp eq i32 %0, %1
54+
br i1 %cmp.work, label %found, label %for.inc
55+
56+
for.inc:
57+
%cmp.inc = icmp ult i32 %conv, %end
58+
br i1 %cmp.inc, label %for.body, label %exit
59+
60+
found:
61+
ret i32 1
62+
63+
exit:
64+
ret i32 0
65+
}

llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@ define void @test1(i64 %x, ptr %a, ptr %b) {
88
; CHECK-NEXT: Loop %header: <multiple exits> Unpredictable backedge-taken count.
99
; CHECK-NEXT: exit count for header: ***COULDNOTCOMPUTE***
1010
; CHECK-NEXT: exit count for latch: ***COULDNOTCOMPUTE***
11+
; CHECK-NEXT: predicated exit count for latch: (-1 + (1 umax %x))
12+
; CHECK-NEXT: Predicates:
13+
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
14+
; CHECK-EMPTY:
1115
; CHECK-NEXT: Loop %header: Unpredictable constant max backedge-taken count.
1216
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
1317
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
1418
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
19+
; CHECK-NEXT: predicated symbolic max exit count for latch: (-1 + (1 umax %x))
20+
; CHECK-NEXT: Predicates:
21+
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
22+
; CHECK-EMPTY:
1523
; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
1624
; CHECK-NEXT: Predicates:
1725
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
@@ -51,10 +59,18 @@ define void @test2(i64 %x, ptr %a) {
5159
; CHECK-NEXT: Loop %header: <multiple exits> Unpredictable backedge-taken count.
5260
; CHECK-NEXT: exit count for header: ***COULDNOTCOMPUTE***
5361
; CHECK-NEXT: exit count for latch: ***COULDNOTCOMPUTE***
62+
; CHECK-NEXT: predicated exit count for latch: (-1 + (1 umax %x))
63+
; CHECK-NEXT: Predicates:
64+
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
65+
; CHECK-EMPTY:
5466
; CHECK-NEXT: Loop %header: Unpredictable constant max backedge-taken count.
5567
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
5668
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
5769
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
70+
; CHECK-NEXT: predicated symbolic max exit count for latch: (-1 + (1 umax %x))
71+
; CHECK-NEXT: Predicates:
72+
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
73+
; CHECK-EMPTY:
5874
; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
5975
; CHECK-NEXT: Predicates:
6076
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>

0 commit comments

Comments
 (0)