diff --git a/llvm/include/llvm/Transforms/Scalar/GVNExpression.h b/llvm/include/llvm/Transforms/Scalar/GVNExpression.h index 2433890d0df80..71437953640c9 100644 --- a/llvm/include/llvm/Transforms/Scalar/GVNExpression.h +++ b/llvm/include/llvm/Transforms/Scalar/GVNExpression.h @@ -315,6 +315,12 @@ class CallExpression final : public MemoryExpression { return EB->getExpressionType() == ET_Call; } + bool equals(const Expression &Other) const override; + bool exactlyEquals(const Expression &Other) const override { + return Expression::exactlyEquals(Other) && + cast<CallExpression>(Other).Call == Call; + } + // Debugging support void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index ad9b1217089d7..13bf0e7dfc87b 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -143,6 +143,8 @@ struct llvm::GVNPass::Expression { Type *type = nullptr; SmallVector<uint32_t, 4> varargs; + AttributeList attrs; + Expression(uint32_t o = ~2U) : opcode(o) {} bool operator==(const Expression &other) const { @@ -154,6 +156,9 @@ struct llvm::GVNPass::Expression { return false; if (varargs != other.varargs) return false; + if (!attrs.isEmpty() && !other.attrs.isEmpty() && + !attrs.intersectWith(type->getContext(), other.attrs).has_value()) + return false; return true; } @@ -364,6 +369,8 @@ GVNPass::Expression GVNPass::ValueTable::createExpr(Instruction *I) { } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) { ArrayRef<int> ShuffleMask = SVI->getShuffleMask(); e.varargs.append(ShuffleMask.begin(), ShuffleMask.end()); + } else if (auto *CB = dyn_cast<CallBase>(I)) { + e.attrs = CB->getAttributes(); } return e; @@ -2189,16 +2196,6 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { return Changed; } -// Return true iff V1 can be replaced with V2. 
-static bool canBeReplacedBy(Value *V1, Value *V2) { - if (auto *CB1 = dyn_cast<CallBase>(V1)) - if (auto *CB2 = dyn_cast<CallBase>(V2)) - return CB1->getAttributes() - .intersectWith(CB2->getContext(), CB2->getAttributes()) - .has_value(); - return true; -} - static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { patchReplacementInstruction(I, Repl); I->replaceAllUsesWith(Repl); @@ -2744,7 +2741,7 @@ bool GVNPass::processInstruction(Instruction *I) { // Perform fast-path value-number based elimination of values inherited from // dominators. Value *Repl = findLeader(I->getParent(), Num); - if (!Repl || !canBeReplacedBy(I, Repl)) { + if (!Repl) { // Failure, just remember this instance for future use. LeaderTable.insert(Num, I, I->getParent()); return false; @@ -3010,7 +3007,7 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) { uint32_t TValNo = VN.phiTranslate(P, CurrentBlock, ValNo, *this); Value *predV = findLeader(P, TValNo); - if (!predV || !canBeReplacedBy(CurInst, predV)) { + if (!predV) { predMap.push_back(std::make_pair(static_cast<Value *>(nullptr), P)); PREPred = P; ++NumWithout; diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp index 6800ad51cc0a8..0cba8739441bc 100644 --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -945,6 +945,18 @@ bool StoreExpression::equals(const Expression &Other) const { return true; } +bool CallExpression::equals(const Expression &Other) const { + if (!MemoryExpression::equals(Other)) + return false; + + if (auto *RHS = dyn_cast<CallExpression>(&Other)) + return Call->getAttributes() + .intersectWith(Call->getContext(), RHS->Call->getAttributes()) + .has_value(); + + return false; +} + // Determine if the edge From->To is a backedge bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const { return From == To || @@ -3854,16 +3866,6 @@ Value *NewGVN::findPHIOfOpsLeader(const Expression *E, return nullptr; } -// Return true iff V1 can be replaced with V2. 
-static bool canBeReplacedBy(Value *V1, Value *V2) { - if (auto *CB1 = dyn_cast<CallBase>(V1)) - if (auto *CB2 = dyn_cast<CallBase>(V2)) - return CB1->getAttributes() - .intersectWith(CB2->getContext(), CB2->getAttributes()) - .has_value(); - return true; -} - bool NewGVN::eliminateInstructions(Function &F) { // This is a non-standard eliminator. The normal way to eliminate is // to walk the dominator tree in order, keeping track of available @@ -3973,8 +3975,6 @@ bool NewGVN::eliminateInstructions(Function &F) { MembersLeft.insert(Member); continue; } - if (!canBeReplacedBy(Member, Leader)) - continue; LLVM_DEBUG(dbgs() << "Found replacement " << *(Leader) << " for " << *Member << "\n"); @@ -4082,11 +4082,8 @@ bool NewGVN::eliminateInstructions(Function &F) { if (DominatingLeader != Def) { // Even if the instruction is removed, we still need to update // flags/metadata due to downstreams users of the leader. - if (!match(DefI, m_Intrinsic<Intrinsic::ssa_copy>())) { - if (!canBeReplacedBy(DefI, DominatingLeader)) - continue; + if (!match(DefI, m_Intrinsic<Intrinsic::ssa_copy>())) patchReplacementInstruction(DefI, DominatingLeader); - } markInstructionForDeletion(DefI); } @@ -4134,11 +4131,8 @@ bool NewGVN::eliminateInstructions(Function &F) { // original operand, as we already know we can just drop it. 
auto *ReplacedInst = cast<Instruction>(U->get()); auto *PI = PredInfo->getPredicateInfoFor(ReplacedInst); - if (!PI || DominatingLeader != PI->OriginalOp) { - if (!canBeReplacedBy(ReplacedInst, DominatingLeader)) - continue; + if (!PI || DominatingLeader != PI->OriginalOp) patchReplacementInstruction(ReplacedInst, DominatingLeader); - } LLVM_DEBUG(dbgs() << "Found replacement " << *DominatingLeader << " for " diff --git a/llvm/test/Transforms/GVN/pr113997.ll b/llvm/test/Transforms/GVN/pr113997.ll index 35e73b1a4439b..6aebdb3a5dfac 100644 --- a/llvm/test/Transforms/GVN/pr113997.ll +++ b/llvm/test/Transforms/GVN/pr113997.ll @@ -31,3 +31,39 @@ if.else: if.then: ret i1 false } + +; Make sure we don't merge these two users of the incompatible call pair. + +define i1 @bucket2(i32 noundef %x) { +; CHECK-LABEL: define i1 @bucket2( +; CHECK-SAME: i32 noundef [[X:%.*]]) { +; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X]], 0 +; CHECK-NEXT: [[CTPOP1:%.*]] = tail call range(i32 1, 32) i32 @llvm.ctpop.i32(i32 zeroext [[X]]) +; CHECK-NEXT: [[CTPOP1INC:%.*]] = add i32 [[CTPOP1]], 1 +; CHECK-NEXT: [[CMP2:%.*]] = icmp samesign ult i32 [[CTPOP1INC]], 3 +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP1]], i1 [[CMP2]], i1 false +; CHECK-NEXT: br i1 [[COND]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] +; CHECK: [[IF_ELSE]]: +; CHECK-NEXT: [[CTPOP2:%.*]] = tail call range(i32 0, 33) i32 @llvm.ctpop.i32(i32 [[X]]) +; CHECK-NEXT: [[CTPOP2INC:%.*]] = add i32 [[CTPOP2]], 1 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i32 [[CTPOP2INC]], 2 +; CHECK-NEXT: ret i1 [[RES]] +; CHECK: [[IF_THEN]]: +; CHECK-NEXT: ret i1 false +; + %cmp1 = icmp sgt i32 %x, 0 + %ctpop1 = tail call range(i32 1, 32) i32 @llvm.ctpop.i32(i32 zeroext %x) + %ctpop1inc = add i32 %ctpop1, 1 + %cmp2 = icmp samesign ult i32 %ctpop1inc, 3 + %cond = select i1 %cmp1, i1 %cmp2, i1 false + br i1 %cond, label %if.then, label %if.else + +if.else: + %ctpop2 = tail call range(i32 0, 33) i32 @llvm.ctpop.i32(i32 %x) + %ctpop2inc = add i32 %ctpop2, 1 + %res = 
icmp eq i32 %ctpop2inc, 2 + ret i1 %res + +if.then: + ret i1 false +} diff --git a/llvm/test/Transforms/NewGVN/pr113997.ll b/llvm/test/Transforms/NewGVN/pr113997.ll index a919c8c304b1b..59dce09e89c88 100644 --- a/llvm/test/Transforms/NewGVN/pr113997.ll +++ b/llvm/test/Transforms/NewGVN/pr113997.ll @@ -31,3 +31,39 @@ if.else: if.then: ret i1 false } + +; Make sure we don't merge these two users of the incompatible call pair. + +define i1 @bucket2(i32 noundef %x) { +; CHECK-LABEL: define i1 @bucket2( +; CHECK-SAME: i32 noundef [[X:%.*]]) { +; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X]], 0 +; CHECK-NEXT: [[CTPOP1:%.*]] = tail call range(i32 1, 32) i32 @llvm.ctpop.i32(i32 zeroext [[X]]) +; CHECK-NEXT: [[CTPOP1INC:%.*]] = add i32 [[CTPOP1]], 1 +; CHECK-NEXT: [[CMP2:%.*]] = icmp samesign ult i32 [[CTPOP1INC]], 3 +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP1]], i1 [[CMP2]], i1 false +; CHECK-NEXT: br i1 [[COND]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] +; CHECK: [[IF_ELSE]]: +; CHECK-NEXT: [[CTPOP2:%.*]] = tail call range(i32 0, 33) i32 @llvm.ctpop.i32(i32 [[X]]) +; CHECK-NEXT: [[CTPOP2INC:%.*]] = add i32 [[CTPOP2]], 1 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i32 [[CTPOP2INC]], 2 +; CHECK-NEXT: ret i1 [[RES]] +; CHECK: [[IF_THEN]]: +; CHECK-NEXT: ret i1 false +; + %cmp1 = icmp sgt i32 %x, 0 + %ctpop1 = tail call range(i32 1, 32) i32 @llvm.ctpop.i32(i32 zeroext %x) + %ctpop1inc = add i32 %ctpop1, 1 + %cmp2 = icmp samesign ult i32 %ctpop1inc, 3 + %cond = select i1 %cmp1, i1 %cmp2, i1 false + br i1 %cond, label %if.then, label %if.else + +if.else: + %ctpop2 = tail call range(i32 0, 33) i32 @llvm.ctpop.i32(i32 %x) + %ctpop2inc = add i32 %ctpop2, 1 + %res = icmp eq i32 %ctpop2inc, 2 + ret i1 %res + +if.then: + ret i1 false +}