Skip to content

[CallPromotionUtils]Implement conditional indirect call promotion with vtable-based comparison #81378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H

namespace llvm {
template <typename T> class ArrayRef;
class Constant;
class CallBase;
class CastInst;
class Function;
class Instruction;
class MDNode;
class Value;

Expand All @@ -41,7 +44,9 @@ bool isLegalToPromote(const CallBase &CB, Function *Callee,
CallBase &promoteCall(CallBase &CB, Function *Callee,
CastInst **RetBitCast = nullptr);

/// Promote the given indirect call site to conditionally call \p Callee.
/// Promote the given indirect call site to conditionally call \p Callee. The
/// promoted direct call instruction is predicated on `CB.getCalledOperand() ==
/// Callee`.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The original call site is moved into the "else" block. A clone of the
Expand All @@ -51,6 +56,22 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
MDNode *BranchWeights = nullptr);

/// This is similar to `promoteCallWithIfThenElse` except that the condition to
/// promote a virtual call is that \p VPtr is the same as any of \p
/// AddressPoints.
///
/// This function is expected to be used on virtual calls (a subset of indirect
/// calls). \p VPtr is the virtual table address stored in the objects, and
/// \p AddressPoints contains vtable address points. A vtable address point is
/// a location inside the vtable that's referenced by vpointer in C++ objects.
///
/// TODO: sink the address-calculation instructions of indirect callee to the
/// indirect call fallback after transformation.
CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
Function *Callee,
ArrayRef<Constant *> AddressPoints,
MDNode *BranchWeights);

/// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
/// success.
///
Expand All @@ -76,11 +97,11 @@ bool tryPromoteCall(CallBase &CB);

/// Predicate and clone the given call site.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The "if" condition compares the call site's called value to the given
/// callee. The original call site is moved into the "else" block, and a clone
/// of the call site is placed in the "then" block. The cloned instruction is
/// returned.
/// This function creates an if-then-else structure at the location of the
/// call site. The "if" condition compares the call site's called value to
/// the given callee. The original call site is moved into the "else" block,
/// and a clone of the call site is placed in the "then" block. The cloned
/// instruction is returned.
CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights);

} // end namespace llvm
Expand Down
32 changes: 28 additions & 4 deletions llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
Expand Down Expand Up @@ -188,9 +190,9 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// Predicate and clone the given call site.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The "if" condition is specified by `Cond`. The original call site is
/// moved into the "else" block, and a clone of the call site is placed in the
/// "then" block. The cloned instruction is returned.
/// site. The "if" condition is specified by `Cond`.
/// The original call site is moved into the "else" block, and a clone of the
/// call site is placed in the "then" block. The cloned instruction is returned.
///
/// For example, the call instruction below:
///
Expand Down Expand Up @@ -518,7 +520,8 @@ CallBase &llvm::promoteCall(CallBase &CB, Function *Callee,
Type *FormalTy = CalleeType->getParamType(ArgNo);
Type *ActualTy = Arg->getType();
if (FormalTy != ActualTy) {
auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
auto *Cast =
CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
CB.setArgOperand(ArgNo, Cast);

// Remove any incompatible attributes for the argument.
Expand Down Expand Up @@ -568,6 +571,27 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}

CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
Function *Callee,
ArrayRef<Constant *> AddressPoints,
MDNode *BranchWeights) {
assert(!AddressPoints.empty() && "Caller should guarantee");
IRBuilder<> Builder(&CB);
SmallVector<Value *, 2> ICmps;
for (auto &AddressPoint : AddressPoints)
ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint));

// TODO: Perform tree height reduction if the number of ICmps is high.
Value *Cond = Builder.CreateOr(ICmps);

// Version the indirect call site. If Cond is true, 'NewInst' will be
// executed, otherwise the original call site will be executed.
CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights);

// Promote 'NewInst' so that it directly calls the desired function.
return promoteCall(NewInst, Callee);
}

bool llvm::tryPromoteCall(CallBase &CB) {
assert(!CB.getCalledFunction());
Module *M = CB.getCaller()->getParent();
Expand Down
88 changes: 88 additions & 0 deletions llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@

#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"

Expand All @@ -24,6 +27,21 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
return Mod;
}

// Returns a constant representing the vtable's address point specified by the
// offset.
static Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
uint32_t AddressPointOffset) {
Module &M = *VTable->getParent();
LLVMContext &Context = M.getContext();
assert(AddressPointOffset <
M.getDataLayout().getTypeAllocSize(VTable->getValueType()) &&
"Out-of-bound access");

return ConstantExpr::getInBoundsGetElementPtr(
Type::getInt8Ty(Context), VTable,
llvm::ConstantInt::get(Type::getInt32Ty(Context), AddressPointOffset));
}

TEST(CallPromotionUtilsTest, TryPromoteCall) {
LLVMContext C;
std::unique_ptr<Module> M = parseIR(C,
Expand Down Expand Up @@ -368,3 +386,73 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
bool IsPromoted = tryPromoteCall(*CI);
EXPECT_FALSE(IsPromoted);
}

TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
LLVMContext C;
std::unique_ptr<Module> M = parseIR(C,
R"IR(
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
@_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6

define i32 @testfunc(ptr %d) {
entry:
%vtable = load ptr, ptr %d, !prof !7
%vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
%0 = load ptr, ptr %vfn
%call = tail call i32 %0(ptr %d), !prof !8
ret i32 %call
}

define i32 @_ZN5Base15func1Ev(ptr %this) {
entry:
ret i32 2
}

declare i32 @_ZN5Base25func2Ev(ptr)
declare i32 @_ZN5Base15func0Ev(ptr)
declare void @_ZN5Base35func3Ev(ptr)

!0 = !{i64 16, !"_ZTS5Base1"}
!1 = !{i64 48, !"_ZTS5Base2"}
!2 = !{i64 16, !"_ZTS8Derived1"}
!3 = !{i64 64, !"_ZTS5Base1"}
!4 = !{i64 40, !"_ZTS5Base2"}
!5 = !{i64 16, !"_ZTS5Base3"}
!6 = !{i64 16, !"_ZTS8Derived2"}
!7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
!8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");

Function *F = M->getFunction("testfunc");
CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
ASSERT_TRUE(CI && CI->isIndirectCall());

// Create the constant and the branch weights
SmallVector<Constant *, 3> VTableAddressPoints;

for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
{"_ZTV8Derived1", 16},
{"_ZTV8Derived2", 64}})
VTableAddressPoints.push_back(getVTableAddressPointOffset(
M->getGlobalVariable(VTableName), AddressPointOffset));

MDBuilder MDB(C);
MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);

size_t OrigEntryBBSize = F->front().size();

LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());

Function *Callee = M->getFunction("_ZN5Base15func1Ev");
// Tests that promoted direct call is returned.
CallBase &DirectCB = promoteCallWithVTableCmp(
*CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
EXPECT_EQ(DirectCB.getCalledOperand(), Callee);

// Promotion inserts 3 icmp instructions and 2 or instructions, and removes
// 1 call instruction from the entry block.
EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
}
Loading