Skip to content

Commit 5d3f296

Browse files
[CallPromotionUtils] Implement conditional indirect call promotion with vtable-based comparison (#81378)
* Given the code sequence

```
bb:
  %vtable = load ptr, ptr %d, !prof !8
  %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
  %1 = load ptr, ptr %vfn
  %call = tail call i32 %1(ptr %d), !prof !9
```

The transformation looks like

```
bb:
  %vtable = load ptr, ptr %d, align 8
  %vfn = getelementptr inbounds i8, ptr %vtable, i64 8    <-- Inst 1
  %func-addr = load ptr, ptr %vfn, align 8                <-- Inst 2
  # compare loaded pointers with address point of vtables
  %1 = icmp eq ptr %vtable, getelementptr inbounds (i8, ptr @_ZTV<VTable>, i32 16)
  br i1 %1, label %if.true.direct_targ, label %if.false.orig_indirect, !prof !18

if.true.direct_targ: ; preds = %bb
  %2 = tail call i32 @<direct-call>(ptr nonnull %d)
  br label %if.end.icp

if.false.orig_indirect: ; preds = %bb
  %call = tail call i32 %func-addr(ptr nonnull %d)
  br label %if.end.icp

if.end.icp: ; preds = %if.false.orig_indirect, %if.true.direct_targ
  %4 = phi i32 [ %call, %if.false.orig_indirect ], [ %2, %if.true.direct_targ ]
```

It's intentional that `Inst 1` and `Inst 2` remain in `bb` (not in `if.false.orig_indirect`). A follow-up patch will implement code to sink them (something like how `instcombine` would [sink](https://github.com/llvm/llvm-project/blob/2fcfc9754a16805b81e541dc8222a8b5cf17a121/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L4293) instructions along with [debug intrinsics](https://github.com/llvm/llvm-project/blob/2fcfc9754a16805b81e541dc8222a8b5cf17a121/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L4356-L4368) if possible)

* The parent patch is #81181
1 parent 2f52bbe commit 5d3f296

File tree

3 files changed

+143
-10
lines changed

3 files changed

+143
-10
lines changed

llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
1616

1717
namespace llvm {
18+
template <typename T> class ArrayRef;
19+
class Constant;
1820
class CallBase;
1921
class CastInst;
2022
class Function;
23+
class Instruction;
2124
class MDNode;
2225
class Value;
2326

@@ -41,7 +44,9 @@ bool isLegalToPromote(const CallBase &CB, Function *Callee,
4144
CallBase &promoteCall(CallBase &CB, Function *Callee,
4245
CastInst **RetBitCast = nullptr);
4346

44-
/// Promote the given indirect call site to conditionally call \p Callee.
47+
/// Promote the given indirect call site to conditionally call \p Callee. The
48+
/// promoted direct call instruction is predicated on `CB.getCalledOperand() ==
49+
/// Callee`.
4550
///
4651
/// This function creates an if-then-else structure at the location of the call
4752
/// site. The original call site is moved into the "else" block. A clone of the
@@ -51,6 +56,22 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
5156
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
5257
MDNode *BranchWeights = nullptr);
5358

59+
/// This is similar to `promoteCallWithIfThenElse` except that the condition to
60+
/// promote a virtual call is that \p VPtr is the same as any of \p
61+
/// AddressPoints.
62+
///
63+
/// This function is expected to be used on virtual calls (a subset of indirect
64+
/// calls). \p VPtr is the virtual table address stored in the objects, and
65+
/// \p AddressPoints contains vtable address points. A vtable address point is
66+
/// a location inside the vtable that's referenced by vpointer in C++ objects.
67+
///
68+
/// TODO: sink the address-calculation instructions of indirect callee to the
69+
/// indirect call fallback after transformation.
70+
CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
71+
Function *Callee,
72+
ArrayRef<Constant *> AddressPoints,
73+
MDNode *BranchWeights);
74+
5475
/// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
5576
/// success.
5677
///
@@ -76,11 +97,11 @@ bool tryPromoteCall(CallBase &CB);
7697

7798
/// Predicate and clone the given call site.
7899
///
79-
/// This function creates an if-then-else structure at the location of the call
80-
/// site. The "if" condition compares the call site's called value to the given
81-
/// callee. The original call site is moved into the "else" block, and a clone
82-
/// of the call site is placed in the "then" block. The cloned instruction is
83-
/// returned.
100+
/// This function creates an if-then-else structure at the location of the
101+
/// call site. The "if" condition compares the call site's called value to
102+
/// the given callee. The original call site is moved into the "else" block,
103+
/// and a clone of the call site is placed in the "then" block. The cloned
104+
/// instruction is returned.
84105
CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights);
85106

86107
} // end namespace llvm

llvm/lib/Transforms/Utils/CallPromotionUtils.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
15+
#include "llvm/ADT/STLExtras.h"
1516
#include "llvm/Analysis/Loads.h"
1617
#include "llvm/Analysis/TypeMetadataUtils.h"
1718
#include "llvm/IR/AttributeMask.h"
19+
#include "llvm/IR/Constant.h"
1820
#include "llvm/IR/IRBuilder.h"
1921
#include "llvm/IR/Instructions.h"
2022
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -188,9 +190,9 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
188190
/// Predicate and clone the given call site.
189191
///
190192
/// This function creates an if-then-else structure at the location of the call
191-
/// site. The "if" condition is specified by `Cond`. The original call site is
192-
/// moved into the "else" block, and a clone of the call site is placed in the
193-
/// "then" block. The cloned instruction is returned.
193+
/// site. The "if" condition is specified by `Cond`.
194+
/// The original call site is moved into the "else" block, and a clone of the
195+
/// call site is placed in the "then" block. The cloned instruction is returned.
194196
///
195197
/// For example, the call instruction below:
196198
///
@@ -518,7 +520,8 @@ CallBase &llvm::promoteCall(CallBase &CB, Function *Callee,
518520
Type *FormalTy = CalleeType->getParamType(ArgNo);
519521
Type *ActualTy = Arg->getType();
520522
if (FormalTy != ActualTy) {
521-
auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
523+
auto *Cast =
524+
CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
522525
CB.setArgOperand(ArgNo, Cast);
523526

524527
// Remove any incompatible attributes for the argument.
@@ -568,6 +571,27 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
568571
return promoteCall(NewInst, Callee);
569572
}
570573

574+
CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
575+
Function *Callee,
576+
ArrayRef<Constant *> AddressPoints,
577+
MDNode *BranchWeights) {
578+
assert(!AddressPoints.empty() && "Caller should guarantee");
579+
IRBuilder<> Builder(&CB);
580+
SmallVector<Value *, 2> ICmps;
581+
for (auto &AddressPoint : AddressPoints)
582+
ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint));
583+
584+
// TODO: Perform tree height reduction if the number of ICmps is high.
585+
Value *Cond = Builder.CreateOr(ICmps);
586+
587+
// Version the indirect call site. If Cond is true, 'NewInst' will be
588+
// executed, otherwise the original call site will be executed.
589+
CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights);
590+
591+
// Promote 'NewInst' so that it directly calls the desired function.
592+
return promoteCall(NewInst, Callee);
593+
}
594+
571595
bool llvm::tryPromoteCall(CallBase &CB) {
572596
assert(!CB.getCalledFunction());
573597
Module *M = CB.getCaller()->getParent();

llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88

99
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
1010
#include "llvm/AsmParser/Parser.h"
11+
#include "llvm/IR/IRBuilder.h"
1112
#include "llvm/IR/Instructions.h"
1213
#include "llvm/IR/LLVMContext.h"
14+
#include "llvm/IR/MDBuilder.h"
1315
#include "llvm/IR/Module.h"
16+
#include "llvm/IR/NoFolder.h"
1417
#include "llvm/Support/SourceMgr.h"
1518
#include "gtest/gtest.h"
1619

@@ -24,6 +27,21 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
2427
return Mod;
2528
}
2629

30+
// Returns a constant representing the vtable's address point specified by the
31+
// offset.
32+
static Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
33+
uint32_t AddressPointOffset) {
34+
Module &M = *VTable->getParent();
35+
LLVMContext &Context = M.getContext();
36+
assert(AddressPointOffset <
37+
M.getDataLayout().getTypeAllocSize(VTable->getValueType()) &&
38+
"Out-of-bound access");
39+
40+
return ConstantExpr::getInBoundsGetElementPtr(
41+
Type::getInt8Ty(Context), VTable,
42+
llvm::ConstantInt::get(Type::getInt32Ty(Context), AddressPointOffset));
43+
}
44+
2745
TEST(CallPromotionUtilsTest, TryPromoteCall) {
2846
LLVMContext C;
2947
std::unique_ptr<Module> M = parseIR(C,
@@ -368,3 +386,73 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
368386
bool IsPromoted = tryPromoteCall(*CI);
369387
EXPECT_FALSE(IsPromoted);
370388
}
389+
390+
TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
391+
LLVMContext C;
392+
std::unique_ptr<Module> M = parseIR(C,
393+
R"IR(
394+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
395+
target triple = "x86_64-unknown-linux-gnu"
396+
397+
@_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
398+
@_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
399+
@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6
400+
401+
define i32 @testfunc(ptr %d) {
402+
entry:
403+
%vtable = load ptr, ptr %d, !prof !7
404+
%vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
405+
%0 = load ptr, ptr %vfn
406+
%call = tail call i32 %0(ptr %d), !prof !8
407+
ret i32 %call
408+
}
409+
410+
define i32 @_ZN5Base15func1Ev(ptr %this) {
411+
entry:
412+
ret i32 2
413+
}
414+
415+
declare i32 @_ZN5Base25func2Ev(ptr)
416+
declare i32 @_ZN5Base15func0Ev(ptr)
417+
declare void @_ZN5Base35func3Ev(ptr)
418+
419+
!0 = !{i64 16, !"_ZTS5Base1"}
420+
!1 = !{i64 48, !"_ZTS5Base2"}
421+
!2 = !{i64 16, !"_ZTS8Derived1"}
422+
!3 = !{i64 64, !"_ZTS5Base1"}
423+
!4 = !{i64 40, !"_ZTS5Base2"}
424+
!5 = !{i64 16, !"_ZTS5Base3"}
425+
!6 = !{i64 16, !"_ZTS8Derived2"}
426+
!7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
427+
!8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");
428+
429+
Function *F = M->getFunction("testfunc");
430+
CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
431+
ASSERT_TRUE(CI && CI->isIndirectCall());
432+
433+
// Create the vtable address-point constants and the branch weights.
434+
SmallVector<Constant *, 3> VTableAddressPoints;
435+
436+
for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
437+
{"_ZTV8Derived1", 16},
438+
{"_ZTV8Derived2", 64}})
439+
VTableAddressPoints.push_back(getVTableAddressPointOffset(
440+
M->getGlobalVariable(VTableName), AddressPointOffset));
441+
442+
MDBuilder MDB(C);
443+
MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);
444+
445+
size_t OrigEntryBBSize = F->front().size();
446+
447+
LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());
448+
449+
Function *Callee = M->getFunction("_ZN5Base15func1Ev");
450+
// Tests that the promoted direct call is returned.
451+
CallBase &DirectCB = promoteCallWithVTableCmp(
452+
*CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
453+
EXPECT_EQ(DirectCB.getCalledOperand(), Callee);
454+
455+
// Promotion inserts 3 icmp instructions and 2 `or` instructions, and removes
456+
// 1 call instruction from the entry block.
457+
EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
458+
}

0 commit comments

Comments
 (0)