diff --git a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h index 63e1ad043d49f..efa9c4bdc496e 100644 --- a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h +++ b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h @@ -106,13 +106,16 @@ class DeadArgumentEliminationPass UseMap Uses; using LiveSet = std::set; - using LiveFuncSet = std::set; + using FuncSet = std::set; /// This set contains all values that have been determined to be live. LiveSet LiveValues; - /// This set contains all values that are cannot be changed in any way. - LiveFuncSet LiveFunctions; + /// This set contains all functions that cannot be changed in any way. + FuncSet FrozenFunctions; + + /// This set contains all functions whose return type cannot be changed. + FuncSet FrozenRetTyFunctions; using UseVector = SmallVector; @@ -131,12 +134,13 @@ class DeadArgumentEliminationPass void markValue(const RetOrArg &RA, Liveness L, const UseVector &MaybeLiveUses); void markLive(const RetOrArg &RA); - void markLive(const Function &F); + void markFrozen(const Function &F); + void markRetTyFrozen(const Function &F); + bool markFnOrRetTyFrozenOnMusttail(const Function &F); void propagateLiveness(const RetOrArg &RA); bool removeDeadStuffFromFunction(Function *F); bool deleteDeadVarargs(Function &F); bool removeDeadArgumentsFromCallers(Function &F); - void propagateVirtMustcallLiveness(const Module &M); }; } // end namespace llvm diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index ed93b4491c50e..2e2687a5ff6e3 100644 --- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -87,11 +87,6 @@ class DAE : public ModulePass { virtual bool shouldHackArguments() const { return false; } }; -bool isMustTailCalleeAnalyzable(const CallBase &CB) { - assert(CB.isMustTailCall()); - return 
CB.getCalledFunction() && !CB.getCalledFunction()->isDeclaration(); -} - } // end anonymous namespace char DAE::ID = 0; @@ -280,7 +275,7 @@ bool DeadArgumentEliminationPass::removeDeadArgumentsFromCallers(Function &F) { // they are fully alive (e.g., called indirectly) and except for the fragile // (variadic) ones. In these cases, we may still be able to improve their // statically known call sites. - if ((F.hasLocalLinkage() && !LiveFunctions.count(&F)) && + if ((F.hasLocalLinkage() && !FrozenFunctions.count(&F)) && !F.getFunctionType()->isVarArg()) return false; @@ -496,7 +491,7 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { // particular register and memory layout. if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca) || F.getAttributes().hasAttrSomewhere(Attribute::Preallocated)) { - markLive(F); + markFrozen(F); return; } @@ -504,7 +499,7 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { // otherwise rely on the frame layout in a way that this analysis will not // see. if (F.hasFnAttribute(Attribute::Naked)) { - markLive(F); + markFrozen(F); return; } @@ -522,29 +517,17 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { // MaybeLive. Initialized to a list of RetCount empty lists. RetUses MaybeLiveRetUses(RetCount); - bool HasMustTailCalls = false; for (const BasicBlock &BB : F) { - // If we have any returns of `musttail` results - the signature can't - // change - if (const auto *TC = BB.getTerminatingMustTailCall()) { - HasMustTailCalls = true; - // In addition, if the called function is not locally defined (or unknown, - // if this is an indirect call), we can't change the callsite and thus - // can't change this function's signature either. 
- if (!isMustTailCalleeAnalyzable(*TC)) { - markLive(F); + if (BB.getTerminatingMustTailCall()) { + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - " << F.getName() + << " has musttail calls\n"); + if (markFnOrRetTyFrozenOnMusttail(F)) return; - } } } - if (HasMustTailCalls) { - LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - " << F.getName() - << " has musttail calls\n"); - } - if (!F.hasLocalLinkage() && (!ShouldHackArguments || F.isIntrinsic())) { - markLive(F); + markFrozen(F); return; } @@ -555,8 +538,6 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { // of them turn out to be live. unsigned NumLiveRetVals = 0; - bool HasMustTailCallers = false; - // Loop all uses of the function. for (const Use &U : F.uses()) { // If the function is PASSED IN as an argument, its address has been @@ -564,14 +545,16 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { const auto *CB = dyn_cast(U.getUser()); if (!CB || !CB->isCallee(&U) || CB->getFunctionType() != F.getFunctionType()) { - markLive(F); + markFrozen(F); return; } - // The number of arguments for `musttail` call must match the number of - // arguments of the caller - if (CB->isMustTailCall()) - HasMustTailCallers = true; + if (CB->isMustTailCall()) { + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - " << F.getName() + << " has musttail callers\n"); + if (markFnOrRetTyFrozenOnMusttail(F)) + return; + } // If we end up here, we are looking at a direct call to our function. @@ -610,11 +593,6 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { } } - if (HasMustTailCallers) { - LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - " << F.getName() - << " has musttail callers\n"); - } - // Now we've inspected all callers, record the liveness of our return values. 
for (unsigned Ri = 0; Ri != RetCount; ++Ri) markValue(createRet(&F, Ri), RetValLiveness[Ri], MaybeLiveRetUses[Ri]); @@ -628,19 +606,12 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { for (Function::const_arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E; ++AI, ++ArgI) { Liveness Result; - if (F.getFunctionType()->isVarArg() || HasMustTailCallers || - HasMustTailCalls) { + if (F.getFunctionType()->isVarArg()) { // Variadic functions will already have a va_arg function expanded inside // them, making them potentially very sensitive to ABI changes resulting // from removing arguments entirely, so don't. For example AArch64 handles // register and stack HFAs very differently, and this is reflected in the // IR which has already been generated. - // - // `musttail` calls to this function restrict argument removal attempts. - // The signature of the caller must match the signature of the function. - // - // `musttail` calls in this function prevents us from changing its - // signature Result = Live; } else { // See what the effect of this use is (recording any uses that cause @@ -680,14 +651,30 @@ void DeadArgumentEliminationPass::markValue(const RetOrArg &RA, Liveness L, } } +/// Return true if we freeze the whole function. +/// If the calling convention is not swifttailcc or tailcc, the caller and +/// callee of musttail must have exactly the same signature. Otherwise we +/// only need to guarantee they have the same return type. +bool DeadArgumentEliminationPass::markFnOrRetTyFrozenOnMusttail( + const Function &F) { + if (F.getCallingConv() != CallingConv::SwiftTail && + F.getCallingConv() != CallingConv::Tail) { + markFrozen(F); + return true; + } else { + markRetTyFrozen(F); + return false; + } +} + /// Mark the given Function as alive, meaning that it cannot be changed in any /// way. Additionally, mark any values that are used as this function's /// parameters or by its return values (according to Uses) live as well. 
-void DeadArgumentEliminationPass::markLive(const Function &F) { - LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Intrinsically live fn: " +void DeadArgumentEliminationPass::markFrozen(const Function &F) { + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - frozen fn: " << F.getName() << "\n"); - // Mark the function as live. - LiveFunctions.insert(&F); + // Mark the function as frozen. + FrozenFunctions.insert(&F); // Mark all arguments as live. for (unsigned ArgI = 0, E = F.arg_size(); ArgI != E; ++ArgI) propagateLiveness(createArg(&F, ArgI)); @@ -696,6 +683,12 @@ void DeadArgumentEliminationPass::markLive(const Function &F) { propagateLiveness(createRet(&F, Ri)); } +void DeadArgumentEliminationPass::markRetTyFrozen(const Function &F) { + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - frozen return type fn: " + << F.getName() << "\n"); + FrozenRetTyFunctions.insert(&F); +} + /// Mark the given return value or argument as live. Additionally, mark any /// values that are used by this value (according to Uses) live as well. void DeadArgumentEliminationPass::markLive(const RetOrArg &RA) { @@ -710,7 +703,7 @@ void DeadArgumentEliminationPass::markLive(const RetOrArg &RA) { } bool DeadArgumentEliminationPass::isLive(const RetOrArg &RA) { - return LiveFunctions.count(RA.F) || LiveValues.count(RA); + return FrozenFunctions.count(RA.F) || LiveValues.count(RA); } /// Given that RA is a live value, propagate it's liveness to any other values @@ -734,8 +727,8 @@ void DeadArgumentEliminationPass::propagateLiveness(const RetOrArg &RA) { /// Transform the function and all the callees of the function to not have these /// arguments and return values. 
bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { - // Don't modify fully live functions - if (LiveFunctions.count(F)) + // Don't modify frozen functions + if (FrozenFunctions.count(F)) return false; // Start by computing a new prototype for the function, which is the same as @@ -807,7 +800,8 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { // performance win, so the second option can just be used always for now. // // This should be revisited if 'returned' is ever applied more liberally. - if (RetTy->isVoidTy() || HasLiveReturnedArg) { + if (RetTy->isVoidTy() || HasLiveReturnedArg || + FrozenRetTyFunctions.count(F)) { NRetTy = RetTy; } else { // Look at each of the original return values individually. @@ -1109,26 +1103,6 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { return true; } -void DeadArgumentEliminationPass::propagateVirtMustcallLiveness( - const Module &M) { - // If a function was marked "live", and it has musttail callers, they in turn - // can't change either. - LiveFuncSet NewLiveFuncs(LiveFunctions); - while (!NewLiveFuncs.empty()) { - LiveFuncSet Temp; - for (const auto *F : NewLiveFuncs) - for (const auto *U : F->users()) - if (const auto *CB = dyn_cast(U)) - if (CB->isMustTailCall()) - if (!LiveFunctions.count(CB->getParent()->getParent())) - Temp.insert(CB->getParent()->getParent()); - NewLiveFuncs.clear(); - NewLiveFuncs.insert(Temp.begin(), Temp.end()); - for (const auto *F : Temp) - markLive(*F); - } -} - PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, ModuleAnalysisManager &) { bool Changed = false; @@ -1149,8 +1123,6 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, for (auto &F : M) surveyFunction(F); - propagateVirtMustcallLiveness(M); - // Now, remove all dead arguments and return values from each function in // turn. We use make_early_inc_range here because functions will probably get // removed (i.e. 
replaced by new ones). diff --git a/llvm/test/Transforms/DeadArgElim/musttail-caller.ll b/llvm/test/Transforms/DeadArgElim/musttail-caller.ll index 4549ab41fb8ad..01666a5f9ad04 100644 --- a/llvm/test/Transforms/DeadArgElim/musttail-caller.ll +++ b/llvm/test/Transforms/DeadArgElim/musttail-caller.ll @@ -1,16 +1,22 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 ; RUN: opt -passes=deadargelim -S < %s | FileCheck %s ; PR36441 ; Dead arguments should not be removed in presence of `musttail` calls. -; CHECK-LABEL: define internal void @test(i32 %a, i32 %b) -; CHECK: musttail call void @foo(i32 %a, i32 0) -; FIXME: we should replace those with `undef`s define internal void @test(i32 %a, i32 %b) { +; CHECK-LABEL: define internal void @test( +; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) { +; CHECK-NEXT: musttail call void @foo(i32 poison, i32 poison) +; CHECK-NEXT: ret void +; musttail call void @foo(i32 %a, i32 0) ret void } -; CHECK-LABEL: define internal void @foo(i32 %a, i32 %b) define internal void @foo(i32 %a, i32 %b) { +; CHECK-LABEL: define internal void @foo( +; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) { +; CHECK-NEXT: ret void +; ret void } diff --git a/llvm/test/Transforms/DeadArgElim/musttail-verifier.ll b/llvm/test/Transforms/DeadArgElim/musttail-verifier.ll new file mode 100644 index 0000000000000..b1be1328940ed --- /dev/null +++ b/llvm/test/Transforms/DeadArgElim/musttail-verifier.ll @@ -0,0 +1,66 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; Test cases come from PR126817 and PR107569 +; See PR54964 and the LangRef for more information about how LLVM currently deals with musttail +; RUN: opt -passes=deadargelim -S < %s | FileCheck %s + +define i64 @A() { +; CHECK-LABEL: define i64 @A() { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[V2660:%.*]] = musttail call i64 @B() +; CHECK-NEXT: ret i64 [[V2660]] +; +entry: + %v2660 = musttail call i64 @B() + ret i64 
%v2660 +} + +define internal i64 @B() { +; CHECK-LABEL: define internal i64 @B() { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: ret i64 0 +; +entry: + ret i64 0 +} + +define internal i64 @C() { +; CHECK-LABEL: define internal i64 @C() { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[V30543:%.*]] = musttail call i64 @B() +; CHECK-NEXT: ret i64 [[V30543]] +; +entry: + %v30543 = musttail call i64 @B() + ret i64 %v30543 +} + +%struct.S = type { double } + +define internal %struct.S @F38() { +; CHECK-LABEL: define internal %struct.S @F38() { +; CHECK-NEXT: ret [[STRUCT_S:%.*]] zeroinitializer +; + ret %struct.S { double 0.0 } +} + +define internal %struct.S @F36() { +; CHECK-LABEL: define internal %struct.S @F36() { +; CHECK-NEXT: [[TMP1:%.*]] = alloca [[STRUCT_S:%.*]], align 8 +; CHECK-NEXT: [[TMP2:%.*]] = musttail call [[STRUCT_S]] @[[F38:[a-zA-Z0-9_$\"\\.-]*[a-zA-Z_$\"\\.-][a-zA-Z0-9_$\"\\.-]*]]() +; CHECK-NEXT: ret [[STRUCT_S]] [[TMP2]] +; + %1 = alloca %struct.S, align 8 + %3 = musttail call %struct.S @F38() + ret %struct.S %3 +} + +define double @foo() { +; CHECK-LABEL: define double @foo() { +; CHECK-NEXT: [[TMP1:%.*]] = call [[STRUCT_S:%.*]] @[[F36:[a-zA-Z0-9_$\"\\.-]*[a-zA-Z_$\"\\.-][a-zA-Z0-9_$\"\\.-]*]]() +; CHECK-NEXT: [[TMP2:%.*]] = extractvalue [[STRUCT_S]] [[TMP1]], 0 +; CHECK-NEXT: ret double [[TMP2]] +; + %3 = call %struct.S @F36() + %5 = extractvalue %struct.S %3, 0 + ret double %5 +}