Skip to content

Commit 6722e83

Browse files
committed
ssaupdaterbulk_add_phi_optimization
1 parent ae46b9e commit 6722e83

File tree

3 files changed

+316
-2
lines changed

3 files changed

+316
-2
lines changed

llvm/include/llvm/Transforms/Utils/SSAUpdaterBulk.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#ifndef LLVM_TRANSFORMS_UTILS_SSAUPDATERBULK_H
1414
#define LLVM_TRANSFORMS_UTILS_SSAUPDATERBULK_H
1515

16-
#include "llvm/ADT/DenseMap.h"
1716
#include "llvm/ADT/StringRef.h"
1817
#include "llvm/IR/PredIteratorCache.h"
1918
#include "llvm/Support/Compiler.h"
@@ -79,6 +78,10 @@ class SSAUpdaterBulk {
7978
LLVM_ABI void
8079
RewriteAllUses(DominatorTree *DT,
8180
SmallVectorImpl<PHINode *> *InsertedPHIs = nullptr);
81+
82+
/// Rewrite all uses and simplify the inserted PHI nodes.
83+
/// Use this method to preserve behavior when replacing SSAUpdater.
84+
void RewriteAndOptimizeAllUses(DominatorTree &DT);
8285
};
8386

8487
} // end namespace llvm

llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "llvm/Transforms/Utils/SSAUpdaterBulk.h"
14+
#include "llvm/Analysis/InstructionSimplify.h"
1415
#include "llvm/Analysis/IteratedDominanceFrontier.h"
1516
#include "llvm/IR/BasicBlock.h"
1617
#include "llvm/IR/Dominators.h"
1718
#include "llvm/IR/IRBuilder.h"
18-
#include "llvm/IR/Instructions.h"
1919
#include "llvm/IR/Use.h"
2020
#include "llvm/IR/Value.h"
2121

@@ -222,3 +222,94 @@ void SSAUpdaterBulk::RewriteAllUses(DominatorTree *DT,
222222
}
223223
}
224224
}
225+
226+
// Perform a single pass of simplification over the worklist of PHIs.
227+
static void simplifyPass(MutableArrayRef<PHINode *> Worklist,
228+
const DataLayout &DL) {
229+
for (PHINode *&PHI : Worklist) {
230+
if (Value *Simplified = simplifyInstruction(PHI, DL)) {
231+
PHI->replaceAllUsesWith(Simplified);
232+
PHI->eraseFromParent();
233+
PHI = nullptr; // Mark as removed.
234+
}
235+
}
236+
}
237+
238+
#ifndef NDEBUG // Should this be under EXPENSIVE_CHECKS?
239+
// New PHI nodes should not reference one another but they may reference
240+
// themselves or existing PHI nodes, and existing PHI nodes may reference new
241+
// PHI nodes.
242+
static bool
243+
PHIAreRefEachOther(const iterator_range<BasicBlock::phi_iterator> &NewPHIs) {
244+
SmallPtrSet<PHINode *, 8> NewPHISet;
245+
for (PHINode &PN : NewPHIs)
246+
NewPHISet.insert(&PN);
247+
for (PHINode &PHI : NewPHIs) {
248+
for (Value *V : PHI.incoming_values()) {
249+
PHINode *IncPHI = dyn_cast<PHINode>(V);
250+
if (IncPHI && IncPHI != &PHI && NewPHISet.contains(IncPHI))
251+
return true;
252+
}
253+
}
254+
return false;
255+
}
256+
#endif
257+
258+
bool EliminateNewDuplicatePHINodes(BasicBlock *BB,
259+
BasicBlock::phi_iterator FirstExistingPN) {
260+
261+
auto NewPHIs = make_range(BB->phis().begin(), FirstExistingPN);
262+
assert(!PHIAreRefEachOther(NewPHIs));
263+
264+
auto ReplaceIfIdentical = [](PHINode &PHI, PHINode &ReplPHI) {
265+
if (!PHI.isIdenticalToWhenDefined(&ReplPHI))
266+
return false;
267+
PHI.replaceAllUsesWith(&ReplPHI);
268+
PHI.eraseFromParent();
269+
return true;
270+
};
271+
272+
// Deduplicate new PHIs first to reduce the number of comparisons on the
273+
// following new -> existing pass.
274+
bool Changed = false;
275+
for (auto I = BB->phis().begin(); I != FirstExistingPN; ++I) {
276+
for (auto J = std::next(I); J != FirstExistingPN;) {
277+
Changed |= ReplaceIfIdentical(*J++, *I);
278+
}
279+
}
280+
281+
// Iterate over existing PHIs and replace identical new PHIs.
282+
for (PHINode &ExistingPHI : make_range(FirstExistingPN, BB->phis().end())) {
283+
auto I = BB->phis().begin();
284+
assert(I != FirstExistingPN); // Should be at least one new PHI.
285+
do {
286+
Changed |= ReplaceIfIdentical(*I++, ExistingPHI);
287+
} while (I != FirstExistingPN);
288+
if (BB->phis().begin() == FirstExistingPN)
289+
return Changed;
290+
}
291+
return Changed;
292+
}
293+
294+
static void deduplicatePass(ArrayRef<PHINode *> Worklist) {
295+
SmallDenseMap<BasicBlock *, unsigned> BBs;
296+
for (PHINode *PHI : Worklist) {
297+
if (PHI)
298+
++BBs[PHI->getParent()];
299+
}
300+
301+
for (auto [BB, NumNewPHIs] : BBs) {
302+
auto FirstExistingPN = std::next(BB->phis().begin(), NumNewPHIs);
303+
EliminateNewDuplicatePHINodes(BB, FirstExistingPN);
304+
}
305+
}
306+
307+
void SSAUpdaterBulk::RewriteAndOptimizeAllUses(DominatorTree &DT) {
308+
SmallVector<PHINode *, 4> PHIs;
309+
RewriteAllUses(&DT, &PHIs);
310+
if (PHIs.empty())
311+
return;
312+
313+
simplifyPass(PHIs, PHIs.front()->getParent()->getDataLayout());
314+
deduplicatePass(PHIs);
315+
}

llvm/unittests/Transforms/Utils/SSAUpdaterBulkTest.cpp

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,223 @@ TEST(SSAUpdaterBulk, TwoBBLoop) {
308308
EXPECT_EQ(Phi->getIncomingValueForBlock(Entry), ConstantInt::get(I32Ty, 0));
309309
EXPECT_EQ(Phi->getIncomingValueForBlock(Loop), I);
310310
}
311+
312+
TEST(SSAUpdaterBulk, SimplifyPHIs) {
313+
const char *IR = R"(
314+
define void @main(i32 %val, i1 %cond) {
315+
entry:
316+
br i1 %cond, label %left, label %right
317+
left:
318+
%add = add i32 %val, 1
319+
br label %exit
320+
right:
321+
%sub = sub i32 %val, 1
322+
br label %exit
323+
exit:
324+
%phi = phi i32 [ %sub, %right ], [ %add, %left ]
325+
%cmp = icmp slt i32 0, 42
326+
ret void
327+
}
328+
)";
329+
330+
llvm::LLVMContext Context;
331+
llvm::SMDiagnostic Err;
332+
std::unique_ptr<llvm::Module> M = llvm::parseAssemblyString(IR, Err, Context);
333+
ASSERT_NE(M, nullptr) << "Failed to parse IR: " << Err.getMessage();
334+
335+
Function *F = M->getFunction("main");
336+
auto *Entry = &F->getEntryBlock();
337+
auto *Left = Entry->getTerminator()->getSuccessor(0);
338+
auto *Right = Entry->getTerminator()->getSuccessor(1);
339+
auto *Exit = Left->getSingleSuccessor();
340+
auto *Val = &*F->arg_begin();
341+
auto *Phi = &Exit->front();
342+
auto *Cmp = &*std::next(Exit->begin());
343+
auto *Add = &Left->front();
344+
auto *Sub = &Right->front();
345+
346+
SSAUpdaterBulk Updater;
347+
Type *I32Ty = Type::getInt32Ty(Context);
348+
349+
// Use %val directly instead of creating a phi.
350+
unsigned ValVar = Updater.AddVariable("Val", I32Ty);
351+
Updater.AddAvailableValue(ValVar, Left, Val);
352+
Updater.AddAvailableValue(ValVar, Right, Val);
353+
Updater.AddUse(ValVar, &Cmp->getOperandUse(0));
354+
355+
// Use existing %phi for %add and %sub values.
356+
unsigned AddSubVar = Updater.AddVariable("AddSub", I32Ty);
357+
Updater.AddAvailableValue(AddSubVar, Left, Add);
358+
Updater.AddAvailableValue(AddSubVar, Right, Sub);
359+
Updater.AddUse(AddSubVar, &Cmp->getOperandUse(1));
360+
361+
auto ExitSizeBefore = Exit->size();
362+
DominatorTree DT(*F);
363+
Updater.RewriteAndOptimizeAllUses(DT);
364+
365+
// Output for Exit->dump():
366+
// exit: ; preds = %right, %left
367+
// %phi = phi i32 [ %sub, %right ], [ %add, %left ]
368+
// %cmp = icmp slt i32 %val, %phi
369+
// ret void
370+
371+
ASSERT_EQ(Exit->size(), ExitSizeBefore);
372+
ASSERT_EQ(&Exit->front(), Phi);
373+
EXPECT_EQ(Val, Cmp->getOperand(0));
374+
EXPECT_EQ(Phi, Cmp->getOperand(1));
375+
}
376+
377+
bool EliminateNewDuplicatePHINodes(BasicBlock *BB,
378+
BasicBlock::phi_iterator FirstExistingPN);
379+
380+
// Helper to run both versions on the same input.
381+
static void RunEliminateNewDuplicatePHINode(
382+
const char *AsmText,
383+
std::function<void(BasicBlock &,
384+
bool(BasicBlock *BB, BasicBlock::phi_iterator))>
385+
Check) {
386+
LLVMContext C;
387+
388+
SMDiagnostic Err;
389+
std::unique_ptr<Module> M = parseAssemblyString(AsmText, Err, C);
390+
if (!M) {
391+
Err.print("UtilsTests", errs());
392+
return;
393+
}
394+
395+
Function *F = M->getFunction("main");
396+
auto BBIt = std::find_if(F->begin(), F->end(), [](const BasicBlock &Block) {
397+
return Block.getName() == "testbb";
398+
});
399+
ASSERT_NE(BBIt, F->end());
400+
Check(*BBIt, EliminateNewDuplicatePHINodes);
401+
}
402+
403+
static BasicBlock::phi_iterator getPhiIt(BasicBlock &BB, unsigned Idx) {
404+
return std::next(BB.phis().begin(), Idx);
405+
}
406+
407+
static PHINode *getPhi(BasicBlock &BB, unsigned Idx) {
408+
return &*getPhiIt(BB, Idx);
409+
}
410+
411+
static int getNumPHIs(BasicBlock &BB) {
412+
return std::distance(BB.phis().begin(), BB.phis().end());
413+
}
414+
415+
TEST(SSAUpdaterBulk, EliminateNewDuplicatePHINodes_OrderExisting) {
416+
RunEliminateNewDuplicatePHINode(R"(
417+
define void @main() {
418+
entry:
419+
br label %testbb
420+
testbb:
421+
%np0 = phi i32 [ 1, %entry ]
422+
%np1 = phi i32 [ 1, %entry ]
423+
%ep0 = phi i32 [ 1, %entry ]
424+
%ep1 = phi i32 [ 1, %entry ]
425+
%u = add i32 %np0, %np1
426+
ret void
427+
}
428+
)", [](BasicBlock &BB, auto *ENDPN) {
429+
AssertingVH<PHINode> EP0 = getPhi(BB, 2);
430+
AssertingVH<PHINode> EP1 = getPhi(BB, 3);
431+
EXPECT_TRUE(ENDPN(&BB, getPhiIt(BB, 2)));
432+
// Expected:
433+
// %ep0 = phi i32 [ 1, %entry ]
434+
// %ep1 = phi i32 [ 1, %entry ]
435+
// %u = add i32 %ep0, %ep0
436+
EXPECT_EQ(getNumPHIs(BB), 2);
437+
Instruction &Add = *BB.getFirstNonPHIIt();
438+
EXPECT_EQ(Add.getOperand(0), EP0);
439+
EXPECT_EQ(Add.getOperand(1), EP0);
440+
(void)EP1; // Avoid "unused" warning.
441+
});
442+
}
443+
444+
TEST(SSAUpdaterBulk, EliminateNewDuplicatePHINodes_OrderNew) {
445+
RunEliminateNewDuplicatePHINode(R"(
446+
define void @main() {
447+
entry:
448+
br label %testbb
449+
testbb:
450+
%np0 = phi i32 [ 1, %entry ]
451+
%np1 = phi i32 [ 1, %entry ]
452+
%ep0 = phi i32 [ 2, %entry ]
453+
%ep1 = phi i32 [ 2, %entry ]
454+
%u = add i32 %np0, %np1
455+
ret void
456+
}
457+
)", [](BasicBlock &BB, auto *ENDPN) {
458+
AssertingVH<PHINode> NP0 = getPhi(BB, 0);
459+
AssertingVH<PHINode> EP0 = getPhi(BB, 2);
460+
AssertingVH<PHINode> EP1 = getPhi(BB, 3);
461+
EXPECT_TRUE(ENDPN(&BB, getPhiIt(BB, 2)));
462+
// Expected:
463+
// %np0 = phi i32 [ 1, %entry ]
464+
// %ep0 = phi i32 [ 2, %entry ]
465+
// %ep1 = phi i32 [ 2, %entry ]
466+
// %u = add i32 %np0, %np0
467+
EXPECT_EQ(getNumPHIs(BB), 3);
468+
Instruction &Add = *BB.getFirstNonPHIIt();
469+
EXPECT_EQ(Add.getOperand(0), NP0);
470+
EXPECT_EQ(Add.getOperand(1), NP0);
471+
(void)EP0;
472+
(void)EP1; // Avoid "unused" warning.
473+
});
474+
}
475+
476+
TEST(SSAUpdaterBulk, EliminateNewDuplicatePHINodes_NewRefExisting) {
477+
RunEliminateNewDuplicatePHINode(R"(
478+
define void @main() {
479+
entry:
480+
br label %testbb
481+
testbb:
482+
%np0 = phi i32 [ 1, %entry ], [ %ep0, %testbb ]
483+
%np1 = phi i32 [ 1, %entry ], [ %ep1, %testbb ]
484+
%ep0 = phi i32 [ 1, %entry ], [ %ep0, %testbb ]
485+
%ep1 = phi i32 [ 1, %entry ], [ %ep1, %testbb ]
486+
%u = add i32 %np0, %np1
487+
br label %testbb
488+
}
489+
)", [](BasicBlock &BB, auto *ENDPN) {
490+
AssertingVH<PHINode> EP0 = getPhi(BB, 2);
491+
AssertingVH<PHINode> EP1 = getPhi(BB, 3);
492+
EXPECT_TRUE(ENDPN(&BB, getPhiIt(BB, 2)));
493+
// Expected:
494+
// %ep0 = phi i32 [ 1, %entry ], [ %ep0, %testbb ]
495+
// %ep1 = phi i32 [ 1, %entry ], [ %ep1, %testbb ]
496+
// %u = add i32 %ep0, %ep1
497+
EXPECT_EQ(getNumPHIs(BB), 2);
498+
Instruction &Add = *BB.getFirstNonPHIIt();
499+
EXPECT_EQ(Add.getOperand(0), EP0);
500+
EXPECT_EQ(Add.getOperand(1), EP1);
501+
});
502+
}
503+
504+
TEST(SSAUpdaterBulk, EliminateNewDuplicatePHINodes_ExistingRefNew) {
505+
RunEliminateNewDuplicatePHINode(R"(
506+
define void @main() {
507+
entry:
508+
br label %testbb
509+
testbb:
510+
%np0 = phi i32 [ 1, %entry ], [ %np0, %testbb ]
511+
%np1 = phi i32 [ 1, %entry ], [ %np1, %testbb ]
512+
%ep0 = phi i32 [ 1, %entry ], [ %np0, %testbb ]
513+
%ep1 = phi i32 [ 1, %entry ], [ %np1, %testbb ]
514+
%u = add i32 %np0, %np1
515+
br label %testbb
516+
}
517+
)", [](BasicBlock &BB, auto *ENDPN) {
518+
AssertingVH<PHINode> EP0 = getPhi(BB, 2);
519+
AssertingVH<PHINode> EP1 = getPhi(BB, 3);
520+
EXPECT_TRUE(ENDPN(&BB, getPhiIt(BB, 2)));
521+
// Expected:
522+
// %ep0 = phi i32 [ 1, %entry ], [ %ep0, %testbb ]
523+
// %ep1 = phi i32 [ 1, %entry ], [ %ep1, %testbb ]
524+
// %u = add i32 %ep0, %ep1
525+
EXPECT_EQ(getNumPHIs(BB), 2);
526+
Instruction &Add = *BB.getFirstNonPHIIt();
527+
EXPECT_EQ(Add.getOperand(0), EP0);
528+
EXPECT_EQ(Add.getOperand(1), EP1);
529+
});
530+
}

0 commit comments

Comments
 (0)