Skip to content

[GVN][NewGVN] Take call attributes into account in expressions #114545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/include/llvm/Transforms/Scalar/GVNExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,12 @@ class CallExpression final : public MemoryExpression {
return EB->getExpressionType() == ET_Call;
}

bool equals(const Expression &Other) const override;
bool exactlyEquals(const Expression &Other) const override {
return Expression::exactlyEquals(Other) &&
cast<CallExpression>(Other).Call == Call;
}

// Debugging support
void printInternal(raw_ostream &OS, bool PrintEType) const override {
if (PrintEType)
Expand Down
21 changes: 9 additions & 12 deletions llvm/lib/Transforms/Scalar/GVN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ struct llvm::GVNPass::Expression {
Type *type = nullptr;
SmallVector<uint32_t, 4> varargs;

AttributeList attrs;

Expression(uint32_t o = ~2U) : opcode(o) {}

bool operator==(const Expression &other) const {
Expand All @@ -154,6 +156,9 @@ struct llvm::GVNPass::Expression {
return false;
if (varargs != other.varargs)
return false;
if (!attrs.isEmpty() && !other.attrs.isEmpty() &&
!attrs.intersectWith(type->getContext(), other.attrs).has_value())
return false;
return true;
}

Expand Down Expand Up @@ -364,6 +369,8 @@ GVNPass::Expression GVNPass::ValueTable::createExpr(Instruction *I) {
} else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
ArrayRef<int> ShuffleMask = SVI->getShuffleMask();
e.varargs.append(ShuffleMask.begin(), ShuffleMask.end());
} else if (auto *CB = dyn_cast<CallBase>(I)) {
e.attrs = CB->getAttributes();
}

return e;
Expand Down Expand Up @@ -2189,16 +2196,6 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
return Changed;
}

// Return true iff V1 can be replaced with V2.
static bool canBeReplacedBy(Value *V1, Value *V2) {
if (auto *CB1 = dyn_cast<CallBase>(V1))
if (auto *CB2 = dyn_cast<CallBase>(V2))
return CB1->getAttributes()
.intersectWith(CB2->getContext(), CB2->getAttributes())
.has_value();
return true;
}

static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) {
patchReplacementInstruction(I, Repl);
I->replaceAllUsesWith(Repl);
Expand Down Expand Up @@ -2744,7 +2741,7 @@ bool GVNPass::processInstruction(Instruction *I) {
// Perform fast-path value-number based elimination of values inherited from
// dominators.
Value *Repl = findLeader(I->getParent(), Num);
if (!Repl || !canBeReplacedBy(I, Repl)) {
if (!Repl) {
// Failure, just remember this instance for future use.
LeaderTable.insert(Num, I, I->getParent());
return false;
Expand Down Expand Up @@ -3010,7 +3007,7 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) {

uint32_t TValNo = VN.phiTranslate(P, CurrentBlock, ValNo, *this);
Value *predV = findLeader(P, TValNo);
if (!predV || !canBeReplacedBy(CurInst, predV)) {
if (!predV) {
predMap.push_back(std::make_pair(static_cast<Value *>(nullptr), P));
PREPred = P;
++NumWithout;
Expand Down
34 changes: 14 additions & 20 deletions llvm/lib/Transforms/Scalar/NewGVN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,18 @@ bool StoreExpression::equals(const Expression &Other) const {
return true;
}

bool CallExpression::equals(const Expression &Other) const {
if (!MemoryExpression::equals(Other))
return false;

if (auto *RHS = dyn_cast<CallExpression>(&Other))
return Call->getAttributes()
.intersectWith(Call->getContext(), RHS->Call->getAttributes())
.has_value();

return false;
}

// Determine if the edge From->To is a backedge
bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const {
return From == To ||
Expand Down Expand Up @@ -3854,16 +3866,6 @@ Value *NewGVN::findPHIOfOpsLeader(const Expression *E,
return nullptr;
}

// Return true iff V1 can be replaced with V2.
static bool canBeReplacedBy(Value *V1, Value *V2) {
if (auto *CB1 = dyn_cast<CallBase>(V1))
if (auto *CB2 = dyn_cast<CallBase>(V2))
return CB1->getAttributes()
.intersectWith(CB2->getContext(), CB2->getAttributes())
.has_value();
return true;
}

bool NewGVN::eliminateInstructions(Function &F) {
// This is a non-standard eliminator. The normal way to eliminate is
// to walk the dominator tree in order, keeping track of available
Expand Down Expand Up @@ -3973,8 +3975,6 @@ bool NewGVN::eliminateInstructions(Function &F) {
MembersLeft.insert(Member);
continue;
}
if (!canBeReplacedBy(Member, Leader))
continue;

LLVM_DEBUG(dbgs() << "Found replacement " << *(Leader) << " for "
<< *Member << "\n");
Expand Down Expand Up @@ -4082,11 +4082,8 @@ bool NewGVN::eliminateInstructions(Function &F) {
if (DominatingLeader != Def) {
// Even if the instruction is removed, we still need to update
// flags/metadata due to downstreams users of the leader.
if (!match(DefI, m_Intrinsic<Intrinsic::ssa_copy>())) {
if (!canBeReplacedBy(DefI, DominatingLeader))
continue;
if (!match(DefI, m_Intrinsic<Intrinsic::ssa_copy>()))
patchReplacementInstruction(DefI, DominatingLeader);
}

markInstructionForDeletion(DefI);
}
Expand Down Expand Up @@ -4134,11 +4131,8 @@ bool NewGVN::eliminateInstructions(Function &F) {
// original operand, as we already know we can just drop it.
auto *ReplacedInst = cast<Instruction>(U->get());
auto *PI = PredInfo->getPredicateInfoFor(ReplacedInst);
if (!PI || DominatingLeader != PI->OriginalOp) {
if (!canBeReplacedBy(ReplacedInst, DominatingLeader))
continue;
if (!PI || DominatingLeader != PI->OriginalOp)
patchReplacementInstruction(ReplacedInst, DominatingLeader);
}

LLVM_DEBUG(dbgs()
<< "Found replacement " << *DominatingLeader << " for "
Expand Down
36 changes: 36 additions & 0 deletions llvm/test/Transforms/GVN/pr113997.ll
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,39 @@ if.else:
if.then:
ret i1 false
}

; Make sure we don't merge these two users of the incompatible call pair.

define i1 @bucket2(i32 noundef %x) {
; CHECK-LABEL: define i1 @bucket2(
; CHECK-SAME: i32 noundef [[X:%.*]]) {
; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X]], 0
; CHECK-NEXT: [[CTPOP1:%.*]] = tail call range(i32 1, 32) i32 @llvm.ctpop.i32(i32 zeroext [[X]])
; CHECK-NEXT: [[CTPOP1INC:%.*]] = add i32 [[CTPOP1]], 1
; CHECK-NEXT: [[CMP2:%.*]] = icmp samesign ult i32 [[CTPOP1INC]], 3
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP1]], i1 [[CMP2]], i1 false
; CHECK-NEXT: br i1 [[COND]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; CHECK: [[IF_ELSE]]:
; CHECK-NEXT: [[CTPOP2:%.*]] = tail call range(i32 0, 33) i32 @llvm.ctpop.i32(i32 [[X]])
; CHECK-NEXT: [[CTPOP2INC:%.*]] = add i32 [[CTPOP2]], 1
; CHECK-NEXT: [[RES:%.*]] = icmp eq i32 [[CTPOP2INC]], 2
; CHECK-NEXT: ret i1 [[RES]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: ret i1 false
;
%cmp1 = icmp sgt i32 %x, 0
%ctpop1 = tail call range(i32 1, 32) i32 @llvm.ctpop.i32(i32 zeroext %x)
%ctpop1inc = add i32 %ctpop1, 1
%cmp2 = icmp samesign ult i32 %ctpop1inc, 3
%cond = select i1 %cmp1, i1 %cmp2, i1 false
br i1 %cond, label %if.then, label %if.else

if.else:
%ctpop2 = tail call range(i32 0, 33) i32 @llvm.ctpop.i32(i32 %x)
%ctpop2inc = add i32 %ctpop2, 1
%res = icmp eq i32 %ctpop2inc, 2
ret i1 %res

if.then:
ret i1 false
}
36 changes: 36 additions & 0 deletions llvm/test/Transforms/NewGVN/pr113997.ll
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,39 @@ if.else:
if.then:
ret i1 false
}

; Make sure we don't merge these two users of the incompatible call pair.

define i1 @bucket2(i32 noundef %x) {
; CHECK-LABEL: define i1 @bucket2(
; CHECK-SAME: i32 noundef [[X:%.*]]) {
; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X]], 0
; CHECK-NEXT: [[CTPOP1:%.*]] = tail call range(i32 1, 32) i32 @llvm.ctpop.i32(i32 zeroext [[X]])
; CHECK-NEXT: [[CTPOP1INC:%.*]] = add i32 [[CTPOP1]], 1
; CHECK-NEXT: [[CMP2:%.*]] = icmp samesign ult i32 [[CTPOP1INC]], 3
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP1]], i1 [[CMP2]], i1 false
; CHECK-NEXT: br i1 [[COND]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; CHECK: [[IF_ELSE]]:
; CHECK-NEXT: [[CTPOP2:%.*]] = tail call range(i32 0, 33) i32 @llvm.ctpop.i32(i32 [[X]])
; CHECK-NEXT: [[CTPOP2INC:%.*]] = add i32 [[CTPOP2]], 1
; CHECK-NEXT: [[RES:%.*]] = icmp eq i32 [[CTPOP2INC]], 2
; CHECK-NEXT: ret i1 [[RES]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: ret i1 false
;
%cmp1 = icmp sgt i32 %x, 0
%ctpop1 = tail call range(i32 1, 32) i32 @llvm.ctpop.i32(i32 zeroext %x)
%ctpop1inc = add i32 %ctpop1, 1
%cmp2 = icmp samesign ult i32 %ctpop1inc, 3
%cond = select i1 %cmp1, i1 %cmp2, i1 false
br i1 %cond, label %if.then, label %if.else

if.else:
%ctpop2 = tail call range(i32 0, 33) i32 @llvm.ctpop.i32(i32 %x)
%ctpop2inc = add i32 %ctpop2, 1
%res = icmp eq i32 %ctpop2inc, 2
ret i1 %res

if.then:
ret i1 false
}
Loading