Skip to content

Commit 5d7f84e

Browse files
committed
LoopRotate: Add code to update branch weights
This adds code to the loop rotation transformation to ensure that the computed block execution counts for the loop bodies are the same before and after the transformation. This isn't always true in practice, but I believe this is because of numeric inaccuracies in the BlockFrequency computation. The invariants this is modeled on and heuristic choice of 0-trip loop amount is explained in a lenghty comment in the new `updateBranchWeights()` function. Differential Revision: https://reviews.llvm.org/D157462
1 parent 285e023 commit 5d7f84e

File tree

3 files changed

+259
-5
lines changed

3 files changed

+259
-5
lines changed

llvm/lib/Transforms/Utils/LoopRotationUtils.cpp

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "llvm/IR/DebugInfo.h"
2626
#include "llvm/IR/Dominators.h"
2727
#include "llvm/IR/IntrinsicInst.h"
28+
#include "llvm/IR/MDBuilder.h"
29+
#include "llvm/IR/ProfDataUtils.h"
2830
#include "llvm/Support/CommandLine.h"
2931
#include "llvm/Support/Debug.h"
3032
#include "llvm/Support/raw_ostream.h"
@@ -50,6 +52,9 @@ static cl::opt<bool>
5052
cl::desc("Allow loop rotation multiple times in order to reach "
5153
"a better latch exit"));
5254

55+
// Probability that a rotated loop has zero trip count / is never entered.
56+
static constexpr uint32_t ZeroTripCountWeights[] = {1, 127};
57+
5358
namespace {
5459
/// A simple loop rotation transformation.
5560
class LoopRotate {
@@ -244,6 +249,93 @@ static bool canRotateDeoptimizingLatchExit(Loop *L) {
244249
return false;
245250
}
246251

252+
static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
253+
bool HasConditionalPreHeader,
254+
bool SuccsSwapped) {
255+
MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI);
256+
if (WeightMD == nullptr)
257+
return;
258+
259+
// LoopBI should currently be a clone of PreHeaderBI with the same
260+
// metadata. But we double check to make sure we don't have a degenerate case
261+
// where instsimplify changed the instructions.
262+
if (WeightMD != getBranchWeightMDNode(LoopBI))
263+
return;
264+
265+
SmallVector<uint32_t, 2> Weights;
266+
extractFromBranchWeightMD(WeightMD, Weights);
267+
if (Weights.size() != 2)
268+
return;
269+
uint32_t OrigLoopExitWeight = Weights[0];
270+
uint32_t OrigLoopBackedgeWeight = Weights[1];
271+
272+
if (SuccsSwapped)
273+
std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight);
274+
275+
// Update branch weights. Consider the following edge-counts:
276+
//
277+
// | |-------- |
278+
// V V | V
279+
// Br i1 ... | Br i1 ...
280+
// | | | | |
281+
// x| y| | becomes: | y0| |-----
282+
// V V | | V V |
283+
// Exit Loop | | Loop |
284+
// | | | Br i1 ... |
285+
// ----- | | | |
286+
// x0| x1| y1 | |
287+
// V V ----
288+
// Exit
289+
//
290+
// The following must hold:
291+
// - x == x0 + x1 # counts to "exit" must stay the same.
292+
// - y0 == x - x0 == x1 # how often loop was entered at all.
293+
// - y1 == y - y0 # How often loop was repeated (after first iter.).
294+
//
295+
// We cannot generally deduce how often we had a zero-trip count loop so we
296+
// have to make a guess for how to distribute x among the new x0 and x1.
297+
298+
uint32_t ExitWeight0 = 0; // aka x0
299+
if (HasConditionalPreHeader) {
300+
// Here we cannot know how many 0-trip count loops we have, so we guess:
301+
if (OrigLoopBackedgeWeight > OrigLoopExitWeight) {
302+
// If the loop count is bigger than the exit count then we set
303+
// probabilities as if 0-trip count nearly never happens.
304+
ExitWeight0 = ZeroTripCountWeights[0];
305+
// Scale up counts if necessary so we can match `ZeroTripCountWeights` for
306+
// the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
307+
while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
308+
// ... but don't overflow.
309+
uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
310+
if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
311+
(OrigLoopExitWeight & HighBit) != 0)
312+
break;
313+
OrigLoopBackedgeWeight <<= 1;
314+
OrigLoopExitWeight <<= 1;
315+
}
316+
} else {
317+
// If there's a higher exit-count than backedge-count then we set
318+
// probabilities as if there are only 0-trip and 1-trip cases.
319+
ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight;
320+
}
321+
}
322+
uint32_t ExitWeight1 = OrigLoopExitWeight - ExitWeight0; // aka x1
323+
uint32_t EnterWeight = ExitWeight1; // aka y0
324+
uint32_t LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight; // aka y1
325+
326+
MDBuilder MDB(LoopBI.getContext());
327+
MDNode *LoopWeightMD =
328+
MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
329+
SuccsSwapped ? ExitWeight1 : LoopBackWeight);
330+
LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
331+
if (HasConditionalPreHeader) {
332+
MDNode *PreHeaderWeightMD =
333+
MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
334+
SuccsSwapped ? ExitWeight0 : EnterWeight);
335+
PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
336+
}
337+
}
338+
247339
/// Rotate loop LP. Return true if the loop is rotated.
248340
///
249341
/// \param SimplifiedLatch is true if the latch was just folded into the final
@@ -363,7 +455,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
363455
// loop. Otherwise loop is not suitable for rotation.
364456
BasicBlock *Exit = BI->getSuccessor(0);
365457
BasicBlock *NewHeader = BI->getSuccessor(1);
366-
if (L->contains(Exit))
458+
bool BISuccsSwapped = L->contains(Exit);
459+
if (BISuccsSwapped)
367460
std::swap(Exit, NewHeader);
368461
assert(NewHeader && "Unable to determine new loop header");
369462
assert(L->contains(NewHeader) && !L->contains(Exit) &&
@@ -605,9 +698,14 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
605698
// to split as many edges.
606699
BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator());
607700
assert(PHBI->isConditional() && "Should be clone of BI condbr!");
608-
if (!isa<ConstantInt>(PHBI->getCondition()) ||
609-
PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) !=
610-
NewHeader) {
701+
const Value *Cond = PHBI->getCondition();
702+
const bool HasConditionalPreHeader =
703+
!isa<ConstantInt>(Cond) ||
704+
PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader;
705+
706+
updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
707+
708+
if (HasConditionalPreHeader) {
611709
// The conditional branch can't be folded, handle the general case.
612710
// Split edges as necessary to preserve LoopSimplify form.
613711

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
; RUN: opt < %s -passes='print<block-freq>' -disable-output 2>&1 | FileCheck %s --check-prefixes=BFI_BEFORE
2+
; RUN: opt < %s -passes='loop(loop-rotate),print<block-freq>' -disable-output 2>&1 | FileCheck %s --check-prefixes=BFI_AFTER
3+
; RUN: opt < %s -passes='loop(loop-rotate)' -S | FileCheck %s --check-prefixes=IR
4+
5+
@g = global i32 0
6+
7+
; We should get the same "count =" results for "outer_loop_body" and
8+
; "inner_loop_body" before and after the transformation.
9+
10+
; BFI_BEFORE-LABEL: block-frequency-info: func0
11+
; BFI_BEFORE: - entry: {{.*}} count = 1
12+
; BFI_BEFORE: - outer_loop_header: {{.*}} count = 1001
13+
; BFI_BEFORE: - outer_loop_body: {{.*}} count = 1000
14+
; BFI_BEFORE: - inner_loop_header: {{.*}} count = 4000
15+
; BFI_BEFORE: - inner_loop_body: {{.*}} count = 3000
16+
; BFI_BEFORE: - inner_loop_exit: {{.*}} count = 1000
17+
; BFI_BEFORE: - outer_loop_exit: {{.*}} count = 1
18+
19+
; BFI_AFTER-LABEL: block-frequency-info: func0
20+
; BFI_AFTER: - entry: {{.*}} count = 1
21+
; BFI_AFTER: - outer_loop_body: {{.*}} count = 1000
22+
; BFI_AFTER: - inner_loop_body: {{.*}} count = 3000
23+
; BFI_AFTER: - inner_loop_exit: {{.*}} count = 1000
24+
; BFI_AFTER: - outer_loop_exit: {{.*}} count = 1
25+
26+
; IR: inner_loop_body:
27+
; IR: br i1 %cmp1, label %inner_loop_body, label %inner_loop_exit, !prof [[PROF_FUNC0_0:![0-9]+]]
28+
; IR: inner_loop_exit:
29+
; IR: br i1 %cmp0, label %outer_loop_body, label %outer_loop_exit, !prof [[PROF_FUNC0_1:![0-9]+]]
30+
;
31+
; A function with known loop-bounds where after loop-rotation we end with an
32+
; unconditional branch in the pre-header.
33+
define void @func0() !prof !0 {
34+
entry:
35+
br label %outer_loop_header
36+
37+
outer_loop_header:
38+
%i0 = phi i32 [0, %entry], [%i0_inc, %inner_loop_exit]
39+
%cmp0 = icmp slt i32 %i0, 1000
40+
br i1 %cmp0, label %outer_loop_body, label %outer_loop_exit, !prof !1
41+
42+
outer_loop_body:
43+
store volatile i32 %i0, ptr @g, align 4
44+
br label %inner_loop_header
45+
46+
inner_loop_header:
47+
%i1 = phi i32 [0, %outer_loop_body], [%i1_inc, %inner_loop_body]
48+
%cmp1 = icmp slt i32 %i1, 3
49+
br i1 %cmp1, label %inner_loop_body, label %inner_loop_exit, !prof !2
50+
51+
inner_loop_body:
52+
store volatile i32 %i1, ptr @g, align 4
53+
%i1_inc = add i32 %i1, 1
54+
br label %inner_loop_header
55+
56+
inner_loop_exit:
57+
%i0_inc = add i32 %i0, 1
58+
br label %outer_loop_header
59+
60+
outer_loop_exit:
61+
ret void
62+
}
63+
64+
; BFI_BEFORE-LABEL: block-frequency-info: func1
65+
; BFI_BEFORE: - entry: {{.*}} count = 1024
66+
; BFI_BEFORE: - loop_header: {{.*}} count = 21504
67+
; BFI_BEFORE: - loop_body: {{.*}} count = 20480
68+
; BFI_BEFORE: - loop_exit: {{.*}} count = 1024
69+
70+
; BFI_AFTER-LABEL: block-frequency-info: func1
71+
; BFI_AFTER: - entry: {{.*}} count = 1024
72+
; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 1024
73+
; BFI_AFTER: - loop_body: {{.*}} count = 20608
74+
; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1024
75+
; BFI_AFTER: - loop_exit: {{.*}} count = 1024
76+
77+
; IR: entry:
78+
; IR: br i1 %cmp1, label %loop_body.lr.ph, label %loop_exit, !prof [[PROF_FUNC1_0:![0-9]+]]
79+
80+
; IR: loop_body:
81+
; IR: br i1 %cmp, label %loop_body, label %loop_header.loop_exit_crit_edge, !prof [[PROF_FUNC1_1:![0-9]+]]
82+
83+
; A function with unknown loop-bounds so loop-rotation ends up with a
84+
; condition jump in pre-header and loop body. branch_weight shows body is
85+
; executed more often than header.
86+
define void @func1(i32 %n) !prof !3 {
87+
entry:
88+
br label %loop_header
89+
90+
loop_header:
91+
%i = phi i32 [0, %entry], [%i_inc, %loop_body]
92+
%cmp = icmp slt i32 %i, %n
93+
br i1 %cmp, label %loop_body, label %loop_exit, !prof !4
94+
95+
loop_body:
96+
store volatile i32 %i, ptr @g, align 4
97+
%i_inc = add i32 %i, 1
98+
br label %loop_header
99+
100+
loop_exit:
101+
ret void
102+
}
103+
104+
; BFI_BEFORE-LABEL: block-frequency-info: func2
105+
; BFI_BEFORE: - entry: {{.*}} count = 1024
106+
; BFI_BEFORE: - loop_header: {{.*}} count = 1056
107+
; BFI_BEFORE: - loop_body: {{.*}} count = 32
108+
; BFI_BEFORE: - loop_exit: {{.*}} count = 1024
109+
110+
; BFI_AFTER-LABEL: block-frequency-info: func2
111+
; - entry: {{.*}} count = 1024
112+
; - loop_body.lr.ph: {{.*}} count = 32
113+
; - loop_body: {{.*}} count = 32
114+
; - loop_header.loop_exit_crit_edge: {{.*}} count = 32
115+
; - loop_exit: {{.*}} count = 1024
116+
117+
; IR: entry:
118+
; IR: br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC2_0:![0-9]+]]
119+
120+
; IR: loop_body:
121+
; IR: br i1 %cmp, label %loop_header.loop_exit_crit_edge, label %loop_body, !prof [[PROF_FUNC2_1:![0-9]+]]
122+
123+
; A function with unknown loop-bounds so loop-rotation ends up with a
124+
; condition jump in pre-header and loop body. Similar to `func1` but here
125+
; loop-exit count is higher than backedge count.
126+
define void @func2(i32 %n) !prof !3 {
127+
entry:
128+
br label %loop_header
129+
130+
loop_header:
131+
%i = phi i32 [0, %entry], [%i_inc, %loop_body]
132+
%cmp = icmp slt i32 %i, %n
133+
br i1 %cmp, label %loop_exit, label %loop_body, !prof !5
134+
135+
loop_body:
136+
store volatile i32 %i, ptr @g, align 4
137+
%i_inc = add i32 %i, 1
138+
br label %loop_header
139+
140+
loop_exit:
141+
ret void
142+
}
143+
144+
!0 = !{!"function_entry_count", i64 1}
145+
!1 = !{!"branch_weights", i32 1000, i32 1}
146+
!2 = !{!"branch_weights", i32 3000, i32 1000}
147+
!3 = !{!"function_entry_count", i64 1024}
148+
!4 = !{!"branch_weights", i32 40, i32 2}
149+
!5 = !{!"branch_weights", i32 10240, i32 320}
150+
151+
; IR: [[PROF_FUNC0_0]] = !{!"branch_weights", i32 2000, i32 1000}
152+
; IR: [[PROF_FUNC0_1]] = !{!"branch_weights", i32 999, i32 1}
153+
; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 127, i32 1}
154+
; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 2433, i32 127}
155+
; IR: [[PROF_FUNC2_0]] = !{!"branch_weights", i32 9920, i32 320}
156+
; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}

llvm/test/Transforms/LoopSimplify/merge-exits.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ define float @merge_branches_profile_metadata(ptr %pTmp1, ptr %peakWeight, i32 %
103103
; CHECK-NEXT: [[T10:%.*]] = fcmp olt float [[T4]], 2.500000e+00
104104
; CHECK-NEXT: [[T12:%.*]] = icmp sgt i64 [[TMP0]], [[INDVARS_IV_NEXT]]
105105
; CHECK-NEXT: [[OR_COND:%.*]] = and i1 [[T10]], [[T12]]
106-
; CHECK-NEXT: br i1 [[OR_COND]], label [[BB]], label [[BB1_BB3_CRIT_EDGE:%.*]], !prof [[PROF0]]
106+
; CHECK-NEXT: br i1 [[OR_COND]], label [[BB]], label [[BB1_BB3_CRIT_EDGE:%.*]], !prof [[PROF1:![0-9]+]]
107107
; CHECK: bb1.bb3_crit_edge:
108108
; CHECK-NEXT: [[T4_LCSSA:%.*]] = phi float [ [[T4]], [[BB]] ]
109109
; CHECK-NEXT: [[T9_LCSSA:%.*]] = phi float [ [[T9]], [[BB]] ]

0 commit comments

Comments
 (0)