Skip to content

Commit f2fa6ad

Browse files
committed
[MergeICmps] Don't reorder unmerged comparisons
MergeICmps will currently sort (by offset) all comparisons in a chain, including those that do not get merged. This is problematic in two ways: * We may end up moving the original first block into the middle of the chain, in which case the "extra work" instructions will also be in the middle of the chain, resulting in invalid IR (reported in https://reviews.llvm.org/D108782#3005583). * Reordering branches is generally not legal, because it may introduce branch on poison, which is UB (PR51845). The merging done by MergeICmps is legal as long as we assume that memcmp() works on frozen memory, but the reordering of unmerged comparisons is definitely incorrect (without inserting freeze instructions), so we should avoid it. There are easier ways to fix the first issue, but I figured it was worthwhile to do this properly to also fix the second one. What we now do is to restore the original relative order of (potentially merged) comparisons. I took the liberty of dropping the MERGEICMPS_DOT_ON functionality, because it would be more awkward to implement now (as the before and after representation is different) and it doesn't seem terribly useful nowadays. Differential Revision: https://reviews.llvm.org/D110024
1 parent 40e971a commit f2fa6ad

File tree

3 files changed

+163
-109
lines changed

3 files changed

+163
-109
lines changed

llvm/lib/Transforms/Scalar/MergeICmps.cpp

Lines changed: 78 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ class BCECmpBlock {
229229
InstructionSet BlockInsts;
230230
// The block requires splitting.
231231
bool RequireSplit = false;
232+
// Original order of this block in the chain.
233+
unsigned OrigOrder = 0;
232234

233235
private:
234236
BCECmp Cmp;
@@ -380,39 +382,83 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
380382
<< Comparison.Rhs().BaseId << " + "
381383
<< Comparison.Rhs().Offset << "\n");
382384
LLVM_DEBUG(dbgs() << "\n");
385+
Comparison.OrigOrder = Comparisons.size();
383386
Comparisons.push_back(std::move(Comparison));
384387
}
385388

386389
// A chain of comparisons.
387390
class BCECmpChain {
388-
public:
389-
BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
390-
AliasAnalysis &AA);
391-
392-
int size() const { return Comparisons_.size(); }
391+
public:
392+
using ContiguousBlocks = std::vector<BCECmpBlock>;
393393

394-
#ifdef MERGEICMPS_DOT_ON
395-
void dump() const;
396-
#endif // MERGEICMPS_DOT_ON
394+
BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
395+
AliasAnalysis &AA);
397396

398397
bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
399398
DomTreeUpdater &DTU);
400399

401-
private:
402-
static bool IsContiguous(const BCECmpBlock &First,
403-
const BCECmpBlock &Second) {
404-
return First.Lhs().BaseId == Second.Lhs().BaseId &&
405-
First.Rhs().BaseId == Second.Rhs().BaseId &&
406-
First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset &&
407-
First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset;
400+
bool atLeastOneMerged() const {
401+
return any_of(MergedBlocks_,
402+
[](const auto &Blocks) { return Blocks.size() > 1; });
408403
}
409404

405+
private:
410406
PHINode &Phi_;
411-
std::vector<BCECmpBlock> Comparisons_;
407+
// The list of all blocks in the chain, grouped by contiguity.
408+
std::vector<ContiguousBlocks> MergedBlocks_;
412409
// The original entry block (before sorting);
413410
BasicBlock *EntryBlock_;
414411
};
415412

413+
static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
414+
return First.Lhs().BaseId == Second.Lhs().BaseId &&
415+
First.Rhs().BaseId == Second.Rhs().BaseId &&
416+
First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset &&
417+
First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset;
418+
}
419+
420+
static unsigned getMinOrigOrder(const BCECmpChain::ContiguousBlocks &Blocks) {
421+
unsigned MinOrigOrder = std::numeric_limits<unsigned>::max();
422+
for (const BCECmpBlock &Block : Blocks)
423+
MinOrigOrder = std::min(MinOrigOrder, Block.OrigOrder);
424+
return MinOrigOrder;
425+
}
426+
427+
/// Given a chain of comparison blocks, groups the blocks into contiguous
428+
/// ranges that can be merged together into a single comparison.
429+
static std::vector<BCECmpChain::ContiguousBlocks>
430+
mergeBlocks(std::vector<BCECmpBlock> &&Blocks) {
431+
std::vector<BCECmpChain::ContiguousBlocks> MergedBlocks;
432+
433+
// Sort to detect continuous offsets.
434+
llvm::sort(Blocks,
435+
[](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) {
436+
return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) <
437+
std::tie(RhsBlock.Lhs(), RhsBlock.Rhs());
438+
});
439+
440+
BCECmpChain::ContiguousBlocks *LastMergedBlock = nullptr;
441+
for (BCECmpBlock &Block : Blocks) {
442+
if (!LastMergedBlock || !areContiguous(LastMergedBlock->back(), Block)) {
443+
MergedBlocks.emplace_back();
444+
LastMergedBlock = &MergedBlocks.back();
445+
} else {
446+
LLVM_DEBUG(dbgs() << "Merging block " << Block.BB->getName() << " into "
447+
<< LastMergedBlock->back().BB->getName() << "\n");
448+
}
449+
LastMergedBlock->push_back(std::move(Block));
450+
}
451+
452+
// While we allow reordering for merging, do not reorder unmerged comparisons.
453+
// Doing so may introduce branch on poison.
454+
llvm::sort(MergedBlocks, [](const BCECmpChain::ContiguousBlocks &LhsBlocks,
455+
const BCECmpChain::ContiguousBlocks &RhsBlocks) {
456+
return getMinOrigOrder(LhsBlocks) < getMinOrigOrder(RhsBlocks);
457+
});
458+
459+
return MergedBlocks;
460+
}
461+
416462
BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
417463
AliasAnalysis &AA)
418464
: Phi_(Phi) {
@@ -492,46 +538,8 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
492538
return;
493539
}
494540
EntryBlock_ = Comparisons[0].BB;
495-
Comparisons_ = std::move(Comparisons);
496-
#ifdef MERGEICMPS_DOT_ON
497-
errs() << "BEFORE REORDERING:\n\n";
498-
dump();
499-
#endif // MERGEICMPS_DOT_ON
500-
// Reorder blocks by LHS. We can do that without changing the
501-
// semantics because we are only accessing dereferencable memory.
502-
llvm::sort(Comparisons_,
503-
[](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) {
504-
return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) <
505-
std::tie(RhsBlock.Lhs(), RhsBlock.Rhs());
506-
});
507-
#ifdef MERGEICMPS_DOT_ON
508-
errs() << "AFTER REORDERING:\n\n";
509-
dump();
510-
#endif // MERGEICMPS_DOT_ON
511-
}
512-
513-
#ifdef MERGEICMPS_DOT_ON
514-
void BCECmpChain::dump() const {
515-
errs() << "digraph dag {\n";
516-
errs() << " graph [bgcolor=transparent];\n";
517-
errs() << " node [color=black,style=filled,fillcolor=lightyellow];\n";
518-
errs() << " edge [color=black];\n";
519-
for (size_t I = 0; I < Comparisons_.size(); ++I) {
520-
const auto &Comparison = Comparisons_[I];
521-
errs() << " \"" << I << "\" [label=\"%"
522-
<< Comparison.Lhs().Base()->getName() << " + "
523-
<< Comparison.Lhs().Offset << " == %"
524-
<< Comparison.Rhs().Base()->getName() << " + "
525-
<< Comparison.Rhs().Offset << " (" << (Comparison.SizeBits() / 8)
526-
<< " bytes)\"];\n";
527-
const Value *const Val = Phi_.getIncomingValueForBlock(Comparison.BB);
528-
if (I > 0) errs() << " \"" << (I - 1) << "\" -> \"" << I << "\";\n";
529-
errs() << " \"" << I << "\" -> \"Phi\" [label=\"" << *Val << "\"];\n";
530-
}
531-
errs() << " \"Phi\" [label=\"Phi\"];\n";
532-
errs() << "}\n\n";
541+
MergedBlocks_ = mergeBlocks(std::move(Comparisons));
533542
}
534-
#endif // MERGEICMPS_DOT_ON
535543

536544
namespace {
537545

@@ -655,47 +663,20 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
655663

656664
bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
657665
DomTreeUpdater &DTU) {
658-
assert(Comparisons_.size() >= 2 && "simplifying trivial BCECmpChain");
659-
// First pass to check if there is at least one merge. If not, we don't do
660-
// anything and we keep analysis passes intact.
661-
const auto AtLeastOneMerged = [this]() {
662-
for (size_t I = 1; I < Comparisons_.size(); ++I) {
663-
if (IsContiguous(Comparisons_[I - 1], Comparisons_[I]))
664-
return true;
665-
}
666-
return false;
667-
};
668-
if (!AtLeastOneMerged())
669-
return false;
670-
666+
assert(atLeastOneMerged() && "simplifying trivial BCECmpChain");
671667
LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block "
672668
<< EntryBlock_->getName() << "\n");
673669

674670
// Effectively merge blocks. We go in the reverse direction from the phi block
675671
// so that the next block is always available to branch to.
676-
const auto mergeRange = [this, &TLI, &AA, &DTU](int I, int Num,
677-
BasicBlock *InsertBefore,
678-
BasicBlock *Next) {
679-
return mergeComparisons(makeArrayRef(Comparisons_).slice(I, Num),
680-
InsertBefore, Next, Phi_, TLI, AA, DTU);
681-
};
682672
int NumMerged = 1;
673+
BasicBlock *InsertBefore = EntryBlock_;
683674
BasicBlock *NextCmpBlock = Phi_.getParent();
684-
for (int I = static_cast<int>(Comparisons_.size()) - 2; I >= 0; --I) {
685-
if (IsContiguous(Comparisons_[I], Comparisons_[I + 1])) {
686-
LLVM_DEBUG(dbgs() << "Merging block " << Comparisons_[I].BB->getName()
687-
<< " into " << Comparisons_[I + 1].BB->getName()
688-
<< "\n");
689-
++NumMerged;
690-
} else {
691-
NextCmpBlock = mergeRange(I + 1, NumMerged, NextCmpBlock, NextCmpBlock);
692-
NumMerged = 1;
693-
}
675+
for (const auto &Blocks : reverse(MergedBlocks_)) {
676+
NumMerged += Blocks.size() - 1;
677+
InsertBefore = NextCmpBlock = mergeComparisons(
678+
Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
694679
}
695-
// Insert the entry block for the new chain before the old entry block.
696-
// If the old entry block was the function entry, this ensures that the new
697-
// entry can become the function entry.
698-
NextCmpBlock = mergeRange(0, NumMerged, EntryBlock_, NextCmpBlock);
699680

700681
// Replace the original cmp chain with the new cmp chain by pointing all
701682
// predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp
@@ -723,13 +704,16 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
723704

724705
// Delete merged blocks. This also removes incoming values in phi.
725706
SmallVector<BasicBlock *, 16> DeadBlocks;
726-
for (auto &Cmp : Comparisons_) {
727-
LLVM_DEBUG(dbgs() << "Deleting merged block " << Cmp.BB->getName() << "\n");
728-
DeadBlocks.push_back(Cmp.BB);
707+
for (const auto &Blocks : MergedBlocks_) {
708+
for (const BCECmpBlock &Block : Blocks) {
709+
LLVM_DEBUG(dbgs() << "Deleting merged block " << Block.BB->getName()
710+
<< "\n");
711+
DeadBlocks.push_back(Block.BB);
712+
}
729713
}
730714
DeleteDeadBlocks(DeadBlocks, &DTU);
731715

732-
Comparisons_.clear();
716+
MergedBlocks_.clear();
733717
return true;
734718
}
735719

@@ -829,8 +813,8 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA,
829813
if (Blocks.empty()) return false;
830814
BCECmpChain CmpChain(Blocks, Phi, AA);
831815

832-
if (CmpChain.size() < 2) {
833-
LLVM_DEBUG(dbgs() << "skip: only one compare block\n");
816+
if (!CmpChain.atLeastOneMerged()) {
817+
LLVM_DEBUG(dbgs() << "skip: nothing merged\n");
834818
return false;
835819
}
836820

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -S -mergeicmps < %s | FileCheck %s
3+
4+
target triple = "x86_64-unknown-linux-gnu"
5+
6+
%"struct.a::c" = type { i32, i32*, i8* }
7+
8+
; The entry block cannot be merged as the comparison is not continuous.
9+
; While it compares the highest address, it should not be moved after the
10+
; other comparisons, as that would make the allocas non-dominating.
11+
12+
define i1 @test() {
13+
; CHECK-LABEL: @test(
14+
; CHECK-NEXT: "land.lhs.true+entry":
15+
; CHECK-NEXT: [[H:%.*]] = alloca %"struct.a::c", align 8
16+
; CHECK-NEXT: [[I:%.*]] = alloca %"struct.a::c", align 8
17+
; CHECK-NEXT: call void @init(%"struct.a::c"* [[H]])
18+
; CHECK-NEXT: call void @init(%"struct.a::c"* [[I]])
19+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[H]], i64 0, i32 1
20+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[I]], i64 0, i32 1
21+
; CHECK-NEXT: [[CSTR:%.*]] = bitcast i32** [[TMP0]] to i8*
22+
; CHECK-NEXT: [[CSTR2:%.*]] = bitcast i32** [[TMP1]] to i8*
23+
; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR2]], i64 16)
24+
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[MEMCMP]], 0
25+
; CHECK-NEXT: br i1 [[TMP2]], label [[LAND_RHS1:%.*]], label [[LAND_END:%.*]]
26+
; CHECK: land.rhs1:
27+
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[H]], i64 0, i32 0
28+
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[I]], i64 0, i32 0
29+
; CHECK-NEXT: [[TMP5:%.*]] = load i32, i32* [[TMP3]], align 4
30+
; CHECK-NEXT: [[TMP6:%.*]] = load i32, i32* [[TMP4]], align 4
31+
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i32 [[TMP5]], [[TMP6]]
32+
; CHECK-NEXT: br label [[LAND_END]]
33+
; CHECK: land.end:
34+
; CHECK-NEXT: [[V9:%.*]] = phi i1 [ [[TMP7]], [[LAND_RHS1]] ], [ false, %"land.lhs.true+entry" ]
35+
; CHECK-NEXT: ret i1 [[V9]]
36+
;
37+
entry:
38+
%h = alloca %"struct.a::c", align 8
39+
%i = alloca %"struct.a::c", align 8
40+
call void @init(%"struct.a::c"* %h)
41+
call void @init(%"struct.a::c"* %i)
42+
%e = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 2
43+
%v3 = load i8*, i8** %e, align 8
44+
%e2 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 2
45+
%v4 = load i8*, i8** %e2, align 8
46+
%cmp = icmp eq i8* %v3, %v4
47+
br i1 %cmp, label %land.lhs.true, label %land.end
48+
49+
land.lhs.true: ; preds = %entry
50+
%d = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 1
51+
%v5 = load i32*, i32** %d, align 8
52+
%d3 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 1
53+
%v6 = load i32*, i32** %d3, align 8
54+
%cmp4 = icmp eq i32* %v5, %v6
55+
br i1 %cmp4, label %land.rhs, label %land.end
56+
57+
land.rhs: ; preds = %land.lhs.true
58+
%j = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 0
59+
%v7 = load i32, i32* %j, align 8
60+
%j5 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 0
61+
%v8 = load i32, i32* %j5, align 8
62+
%cmp6 = icmp eq i32 %v7, %v8
63+
br label %land.end
64+
65+
land.end: ; preds = %land.rhs, %land.lhs.true, %entry
66+
%v9 = phi i1 [ false, %land.lhs.true ], [ false, %entry ], [ %cmp6, %land.rhs ]
67+
ret i1 %v9
68+
}
69+
70+
declare void @init(%"struct.a::c"*)

llvm/test/Transforms/MergeICmps/X86/entry-block-shuffled.ll

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99

1010
define zeroext i1 @opeq1(
1111
; CHECK-LABEL: @opeq1(
12-
; CHECK-NEXT: "land.rhs.i+land.rhs.i.2":
13-
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [[S:%.*]], %S* [[A:%.*]], i64 0, i32 0
14-
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [[S]], %S* [[B:%.*]], i64 0, i32 0
15-
; CHECK-NEXT: [[CSTR:%.*]] = bitcast i32* [[TMP0]] to i8*
16-
; CHECK-NEXT: [[CSTR3:%.*]] = bitcast i32* [[TMP1]] to i8*
17-
; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR3]], i64 8)
18-
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[MEMCMP]], 0
19-
; CHECK-NEXT: br i1 [[TMP2]], label [[ENTRY2:%.*]], label [[OPEQ1_EXIT:%.*]]
20-
; CHECK: entry2:
21-
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 3
22-
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[S]], %S* [[B]], i64 0, i32 2
23-
; CHECK-NEXT: [[TMP5:%.*]] = load i32, i32* [[TMP3]], align 4
24-
; CHECK-NEXT: [[TMP6:%.*]] = load i32, i32* [[TMP4]], align 4
25-
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i32 [[TMP5]], [[TMP6]]
12+
; CHECK-NEXT: entry3:
13+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [[S:%.*]], %S* [[A:%.*]], i64 0, i32 3
14+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [[S]], %S* [[B:%.*]], i64 0, i32 2
15+
; CHECK-NEXT: [[TMP2:%.*]] = load i32, i32* [[TMP0]], align 4
16+
; CHECK-NEXT: [[TMP3:%.*]] = load i32, i32* [[TMP1]], align 4
17+
; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i32 [[TMP2]], [[TMP3]]
18+
; CHECK-NEXT: br i1 [[TMP4]], label %"land.rhs.i+land.rhs.i.2", label [[OPEQ1_EXIT:%.*]]
19+
; CHECK: "land.rhs.i+land.rhs.i.2":
20+
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 0
21+
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[S]], %S* [[B]], i64 0, i32 0
22+
; CHECK-NEXT: [[CSTR:%.*]] = bitcast i32* [[TMP5]] to i8*
23+
; CHECK-NEXT: [[CSTR2:%.*]] = bitcast i32* [[TMP6]] to i8*
24+
; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR2]], i64 8)
25+
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i32 [[MEMCMP]], 0
2626
; CHECK-NEXT: br i1 [[TMP7]], label [[LAND_RHS_I_31:%.*]], label [[OPEQ1_EXIT]]
2727
; CHECK: land.rhs.i.31:
2828
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 3
@@ -32,7 +32,7 @@ define zeroext i1 @opeq1(
3232
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i32 [[TMP10]], [[TMP11]]
3333
; CHECK-NEXT: br label [[OPEQ1_EXIT]]
3434
; CHECK: opeq1.exit:
35-
; CHECK-NEXT: [[TMP13:%.*]] = phi i1 [ [[TMP12]], [[LAND_RHS_I_31]] ], [ false, [[ENTRY2]] ], [ false, %"land.rhs.i+land.rhs.i.2" ]
35+
; CHECK-NEXT: [[TMP13:%.*]] = phi i1 [ [[TMP12]], [[LAND_RHS_I_31]] ], [ false, %"land.rhs.i+land.rhs.i.2" ], [ false, [[ENTRY3:%.*]] ]
3636
; CHECK-NEXT: ret i1 [[TMP13]]
3737
;
3838
%S* nocapture readonly dereferenceable(16) %a,

0 commit comments

Comments
 (0)