Skip to content

Commit 54ffdf4

Browse files
committed
[FMV][GlobalOpt] Bypass the IFunc Resolver of MultiVersioned functions.
To deduce whether the optimization is legal we need to compare the target features between caller and callee versions. The criteria for bypassing the resolver are the following:

* If the callee's feature set is a subset of the caller's feature set, then the callee is a candidate for a direct call.
* Among such candidates, the one of highest priority is the best match and it shall be picked, unless there is a version of the callee with higher priority than the best match which cannot be picked because there is no corresponding caller for whom it would have been the best match.

Implementation details: First we collect all the callee versions in feature priority order. We do the same for all the callsites. Then we try to constant fold the resolver for every callsite, starting from the higher priority callers. This guarantees that as soon as we find a callee whose priority is lower than the expected best match, there is no point in continuing further. The constant folding works for single basic block resolvers as well as for resolvers consisting of multiple basic blocks. We only attempt to fold a handful of instruction kinds (return, binop, compare, select, branch, phi), and we only follow single-user use-def chains. For callsites residing in the same caller we cache the folded result to avoid redundant computation.
1 parent a522dbb commit 54ffdf4

File tree

8 files changed

+501
-3
lines changed

8 files changed

+501
-3
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,16 @@ class TargetTransformInfo {
17621762
/// false, but it shouldn't matter what it returns anyway.
17631763
bool hasArmWideBranch(bool Thumb) const;
17641764

1765+
/// Returns true if the target supports Function MultiVersioning.
1766+
bool hasFMV() const;
1767+
1768+
/// Returns the MultiVersion priority of a given function.
1769+
uint64_t getFMVPriority(Function &F) const;
1770+
1771+
/// Returns the symbol which contains the cpu feature mask used by
1772+
/// the Function MultiVersioning resolver.
1773+
GlobalVariable *getCPUFeatures(Module &M) const;
1774+
17651775
/// \return The maximum number of function arguments the target supports.
17661776
unsigned getMaxNumArgs() const;
17671777

@@ -2152,6 +2162,9 @@ class TargetTransformInfo::Concept {
21522162
virtual VPLegalization
21532163
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
21542164
virtual bool hasArmWideBranch(bool Thumb) const = 0;
2165+
virtual bool hasFMV() const = 0;
2166+
virtual uint64_t getFMVPriority(Function &F) const = 0;
2167+
virtual GlobalVariable *getCPUFeatures(Module &M) const = 0;
21552168
virtual unsigned getMaxNumArgs() const = 0;
21562169
};
21572170

@@ -2904,6 +2917,16 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
29042917
return Impl.hasArmWideBranch(Thumb);
29052918
}
29062919

2920+
/// Forwards the Function MultiVersioning support query to Impl.
bool hasFMV() const override { return Impl.hasFMV(); }
2921+
2922+
/// Forwards the MultiVersion priority query for \p F to Impl.
uint64_t getFMVPriority(Function &F) const override {
2923+
return Impl.getFMVPriority(F);
2924+
}
2925+
2926+
/// Forwards the lookup of the CPU feature mask symbol in \p M to Impl.
GlobalVariable *getCPUFeatures(Module &M) const override {
2927+
return Impl.getCPUFeatures(M);
2928+
}
2929+
29072930
unsigned getMaxNumArgs() const override {
29082931
return Impl.getMaxNumArgs();
29092932
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,12 @@ class TargetTransformInfoImplBase {
941941

942942
bool hasArmWideBranch(bool) const { return false; }
943943

944+
/// Default implementation: Function MultiVersioning is reported as
/// unsupported.
bool hasFMV() const { return false; }
945+
946+
/// Default implementation: every function gets priority 0; \p F is not
/// inspected.
uint64_t getFMVPriority(Function &F) const { return 0; }
947+
948+
/// Default implementation: there is no CPU feature mask symbol, so nullptr is
/// returned; \p M is not inspected.
GlobalVariable *getCPUFeatures(Module &M) const { return nullptr; }
949+
944950
unsigned getMaxNumArgs() const { return UINT_MAX; }
945951

946952
protected:

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,16 @@ bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const {
12961296
return TTIImpl->hasArmWideBranch(Thumb);
12971297
}
12981298

1299+
// Delegates to the target-specific implementation held in TTIImpl.
bool TargetTransformInfo::hasFMV() const { return TTIImpl->hasFMV(); }
1300+
1301+
// Delegates to the target-specific implementation held in TTIImpl.
uint64_t TargetTransformInfo::getFMVPriority(Function &F) const {
1302+
return TTIImpl->getFMVPriority(F);
1303+
}
1304+
1305+
// Delegates to the target-specific implementation held in TTIImpl.
GlobalVariable *TargetTransformInfo::getCPUFeatures(Module &M) const {
1306+
return TTIImpl->getCPUFeatures(M);
1307+
}
1308+
12991309
unsigned TargetTransformInfo::getMaxNumArgs() const {
13001310
return TTIImpl->getMaxNumArgs();
13011311
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/IR/IntrinsicsAArch64.h"
2222
#include "llvm/IR/PatternMatch.h"
2323
#include "llvm/Support/Debug.h"
24+
#include "llvm/TargetParser/AArch64TargetParser.h"
2425
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2526
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
2627
#include <algorithm>
@@ -231,6 +232,17 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
231232
return false;
232233
}
233234

235+
/// Compute the MultiVersion priority of \p F: the comma-separated value of
/// its "target-features" function attribute is split into individual feature
/// strings, which are then turned into a mask by AArch64::getCpuSupportsMask.
uint64_t AArch64TTIImpl::getFMVPriority(Function &F) const {
  SmallVector<StringRef, 8> FeatureList;
  F.getFnAttribute("target-features")
      .getValueAsString()
      .split(FeatureList, ",");
  return AArch64::getCpuSupportsMask(FeatureList);
}
241+
242+
/// Look up the global variable named "__aarch64_cpu_features" in \p M and
/// return the result of that lookup (presumably nullptr when the module does
/// not contain the symbol -- confirm against Module::getGlobalVariable).
GlobalVariable *AArch64TTIImpl::getCPUFeatures(Module &M) const {
243+
return M.getGlobalVariable("__aarch64_cpu_features");
244+
}
245+
234246
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
235247
const Function *Callee) const {
236248
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
8383
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
8484
unsigned DefaultCallPenalty) const;
8585

86+
/// Reports whether the subtarget (ST) has Function MultiVersioning.
bool hasFMV() const { return ST->hasFMV(); }
87+
88+
uint64_t getFMVPriority(Function &F) const;
89+
90+
GlobalVariable *getCPUFeatures(Module &M) const;
91+
8692
/// \name Scalar TTI Implementations
8793
/// @{
8894

llvm/lib/TargetParser/AArch64TargetParser.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,13 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA
5050
// Build a feature bitmask from a list of feature strings. In the added lines
// of this hunk each string is first passed through resolveExtAlias and then
// compared against both the Name and the Feature field of every entry in
// Extensions; the first matching entry contributes bit (1 << CPUFeature) to
// the mask. Strings that match no entry contribute nothing. The removed lines
// performed the lookup through parseArchExtension instead.
uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs) {
5151
uint64_t FeaturesMask = 0;
5252
for (const StringRef &FeatureStr : FeatureStrs) {
53-
if (auto Ext = parseArchExtension(FeatureStr))
54-
FeaturesMask |= (1ULL << Ext->CPUFeature);
53+
StringRef Feat = resolveExtAlias(FeatureStr);
54+
for (const auto &E : Extensions) {
55+
if (Feat == E.Name || Feat == E.Feature) {
56+
FeaturesMask |= (1ULL << E.CPUFeature);
57+
break;
58+
}
59+
}
5560
}
5661
return FeaturesMask;
5762
}

llvm/lib/Transforms/IPO/GlobalOpt.cpp

Lines changed: 226 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ STATISTIC(NumAliasesRemoved, "Number of global aliases eliminated");
8989
STATISTIC(NumCXXDtorsRemoved, "Number of global C++ destructors removed");
9090
STATISTIC(NumInternalFunc, "Number of internal functions");
9191
STATISTIC(NumColdCC, "Number of functions marked coldcc");
92-
STATISTIC(NumIFuncsResolved, "Number of statically resolved IFuncs");
92+
STATISTIC(NumIFuncsResolved, "Number of resolved IFuncs");
9393
STATISTIC(NumIFuncsDeleted, "Number of IFuncs removed");
9494

9595
static cl::opt<bool>
@@ -2462,6 +2462,228 @@ DeleteDeadIFuncs(Module &M,
24622462
return Changed;
24632463
}
24642464

2465+
/// Try to evaluate the resolver of the IFunc called at \p CS at compile time,
/// pretending that the CPU feature mask it loads has the value \p Priority.
///
/// The resolver's entry block is searched for a load of the symbol returned by
/// TTI.getCPUFeatures(). A constant standing in for that load is then pushed
/// through single-use chains of return, binary operator, compare, select, phi
/// and conditional branch instructions.
///
/// \param CS       Call whose called operand is a GlobalIFunc.
/// \param Priority Value substituted for the loaded feature mask.
/// \param TTI      Provides the CPU features symbol for the resolver's module.
/// \returns the Function the resolver folds to, or nullptr when the resolver
///          cannot be folded (including when it folds to a constant that is
///          not a Function).
static Function *foldResolverForCallSite(CallBase *CS, uint64_t Priority,
                                         TargetTransformInfo &TTI) {
  auto *IF = cast<GlobalIFunc>(CS->getCalledOperand());
  Function *Resolver = IF->getResolverFunction();
  // Without a resolver body there is no entry block to inspect.
  if (!Resolver || Resolver->isDeclaration())
    return nullptr;

  // Query the feature mask symbol once instead of once per instruction.
  GlobalVariable *CPUFeatures = TTI.getCPUFeatures(*Resolver->getParent());
  if (!CPUFeatures)
    return nullptr;

  // Look for the instruction which feeds the feature mask to the users.
  Instruction *Root = nullptr;
  for (Instruction &I : Resolver->getEntryBlock()) {
    auto *Load = dyn_cast<LoadInst>(&I);
    if (Load && Load->getPointerOperand() == CPUFeatures) {
      Root = Load;
      break;
    }
  }
  // There is no such instruction. Bail.
  if (!Root)
    return nullptr;

  // The seed below is an integer constant; getIntegerBitWidth() is only valid
  // on integer types.
  if (!Root->getType()->isIntegerTy())
    return nullptr;

  // Create a constant mask to use as seed for the constant propagation.
  Constant *Seed = Constant::getIntegerValue(
      Root->getType(), APInt(Root->getType()->getIntegerBitWidth(), Priority));

  // Bind by reference; the DataLayout does not need to be copied.
  const DataLayout &DL = CS->getModule()->getDataLayout();

  // Recursively propagate on single use chains. I is the instruction being
  // folded, Use is the operand of I whose value is known to be C, and Pred
  // (if set) is the block control arrived from when I is reached by folding a
  // branch.
  std::function<Constant *(Instruction *, Instruction *, Constant *,
                           BasicBlock *)>
      constFoldInst = [&](Instruction *I, Instruction *Use, Constant *C,
                          BasicBlock *Pred) -> Constant * {
    // Base case.
    if (auto *Ret = dyn_cast<ReturnInst>(I))
      return Ret->getReturnValue() == Use ? C : nullptr;

    // Minimal set of instruction types to handle.
    if (auto *BinOp = dyn_cast<BinaryOperator>(I)) {
      bool Swap = BinOp->getOperand(1) == Use;
      // When reached through a folded branch the instruction may not consume
      // Use at all; folding it with C would be meaningless.
      if (!Swap && BinOp->getOperand(0) != Use)
        return nullptr;
      // The result is only known if the other operand is a constant too.
      auto *Other = dyn_cast<Constant>(BinOp->getOperand(Swap ? 0 : 1));
      if (!Other)
        return nullptr;
      C = Swap ? ConstantFoldBinaryInstruction(BinOp->getOpcode(), Other, C)
               : ConstantFoldBinaryInstruction(BinOp->getOpcode(), C, Other);
    } else if (auto *Cmp = dyn_cast<CmpInst>(I)) {
      bool Swap = Cmp->getOperand(1) == Use;
      if (!Swap && Cmp->getOperand(0) != Use)
        return nullptr;
      auto *Other = dyn_cast<Constant>(Cmp->getOperand(Swap ? 0 : 1));
      if (!Other)
        return nullptr;
      C = Swap ? ConstantFoldCompareInstOperands(Cmp->getPredicate(), Other,
                                                 C, DL)
               : ConstantFoldCompareInstOperands(Cmp->getPredicate(), C,
                                                 Other, DL);
    } else if (auto *Sel = dyn_cast<SelectInst>(I)) {
      // A known condition picks one arm. When Use is one of the arms the
      // value is deliberately passed through unchanged: users are tried in
      // execution order, so this point is only reached after the earlier
      // (higher priority) users failed to produce a result.
      if (Sel->getCondition() == Use)
        C = dyn_cast<Constant>(C->isZeroValue() ? Sel->getFalseValue()
                                                : Sel->getTrueValue());
    } else if (auto *Phi = dyn_cast<PHINode>(I)) {
      // With a known predecessor the incoming value is selected; otherwise
      // the value of Use is passed through, as for select above.
      if (Pred)
        C = dyn_cast<Constant>(Phi->getIncomingValueForBlock(Pred));
    } else if (auto *Br = dyn_cast<BranchInst>(I)) {
      // getCondition() is only valid on conditional branches.
      if (Br->isConditional() && Br->getCondition() == Use) {
        // Successor 0 is taken when the condition is true.
        BasicBlock *BB = Br->getSuccessor(C->isZeroValue());
        return constFoldInst(&BB->front(), Root, Seed, Br->getParent());
      }
    } else {
      // Don't know how to handle. Bail.
      return nullptr;
    }

    // Folding succeeded. Continue.
    if (C && I->hasOneUse())
      if (auto *UI = dyn_cast<Instruction>(I->user_back()))
        return constFoldInst(UI, I, C, nullptr);

    return nullptr;
  };

  // Collect all users in the entry block ordered by proximity. The rest of
  // them can be discovered later. Unfortunately we cannot simply traverse
  // the Root's 'users()' as their order is not the same as execution order.
  unsigned NUsersLeft = std::distance(Root->user_begin(), Root->user_end());
  SmallVector<Instruction *> Users;
  for (Instruction &I : *Root->getParent()) {
    if (any_of(I.operands(), [Root](auto &Op) { return Op == Root; })) {
      Users.push_back(&I);
      if (--NUsersLeft == 0)
        break;
    }
  }

  // Return as soon as we find a foldable user. It has the highest priority.
  // A folded constant that is not a Function (e.g. a null pointer) cannot be
  // used as a call target, so it is reported as a failure instead of being
  // cast unconditionally.
  for (Instruction *I : Users)
    if (Constant *C = constFoldInst(I, Root, Seed, nullptr))
      return dyn_cast<Function>(C);

  return nullptr;
}
2558+
2559+
// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
2560+
// deduce whether the optimization is legal we need to compare the target
2561+
// features between caller and callee versions. The criteria for bypassing
2562+
// the resolver are the following:
2563+
//
2564+
// * If the callee's feature set is a subset of the caller's feature set,
2565+
// then the callee is a candidate for direct call.
2566+
//
2567+
// * Among such candidates the one of highest priority is the best match
2568+
// and it shall be picked, unless there is a version of the callee with
2569+
// higher priority than the best match which cannot be picked because
2570+
// there is no corresponding caller for whom it would have been the best
2571+
// match.
2572+
//
2573+
static bool OptimizeNonTrivialIFuncs(
2574+
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
2575+
bool Changed = false;
2576+
2577+
std::function<void(Value *, SmallVectorImpl<Function *> &)> visitValue =
2578+
[&](Value *V, SmallVectorImpl<Function *> &FuncVersions) {
2579+
if (auto *Func = dyn_cast<Function>(V)) {
2580+
FuncVersions.push_back(Func);
2581+
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2582+
visitValue(Sel->getTrueValue(), FuncVersions);
2583+
visitValue(Sel->getFalseValue(), FuncVersions);
2584+
} else if (auto *Phi = dyn_cast<PHINode>(V))
2585+
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
2586+
visitValue(Phi->getIncomingValue(I), FuncVersions);
2587+
};
2588+
2589+
// Cache containing the mask constructed from a function's target features.
2590+
DenseMap<Function *, uint64_t> FeaturePriorityMap;
2591+
2592+
for (GlobalIFunc &IF : M.ifuncs()) {
2593+
if (IF.isInterposable())
2594+
continue;
2595+
2596+
Function *Resolver = IF.getResolverFunction();
2597+
if (!Resolver)
2598+
continue;
2599+
2600+
if (Resolver->isInterposable())
2601+
continue;
2602+
2603+
TargetTransformInfo &TTI = GetTTI(*Resolver);
2604+
if (!TTI.hasFMV())
2605+
return false;
2606+
2607+
// Discover the callee versions.
2608+
SmallVector<Function *> Callees;
2609+
for (BasicBlock &BB : *Resolver)
2610+
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
2611+
visitValue(Ret->getReturnValue(), Callees);
2612+
2613+
if (Callees.empty())
2614+
continue;
2615+
2616+
// Cache the feature mask for each callee.
2617+
for (Function *Callee : Callees) {
2618+
auto [It, Inserted] = FeaturePriorityMap.try_emplace(Callee);
2619+
if (Inserted)
2620+
It->second = TTI.getFMVPriority(*Callee);
2621+
}
2622+
2623+
// Sort the callee versions in increasing feature priority order.
2624+
// Every time we find a caller that matches the highest priority
2625+
// callee we pop_back() one from this ordered list.
2626+
llvm::stable_sort(Callees, [&](auto *LHS, auto *RHS) {
2627+
return FeaturePriorityMap[LHS] < FeaturePriorityMap[RHS];
2628+
});
2629+
2630+
// Find the callsites and cache the feature mask for each caller.
2631+
SmallVector<CallBase *> CallSites;
2632+
for (User *U : IF.users()) {
2633+
if (auto *CB = dyn_cast<CallBase>(U)) {
2634+
if (CB->getCalledOperand() == &IF) {
2635+
Function *Caller = CB->getFunction();
2636+
auto [It, Inserted] = FeaturePriorityMap.try_emplace(Caller);
2637+
if (Inserted)
2638+
It->second = TTI.getFMVPriority(*Caller);
2639+
CallSites.push_back(CB);
2640+
}
2641+
}
2642+
}
2643+
2644+
// Sort the callsites in decreasing feature priority order.
2645+
llvm::stable_sort(CallSites, [&](auto *LHS, auto *RHS) {
2646+
return FeaturePriorityMap[LHS->getFunction()] >
2647+
FeaturePriorityMap[RHS->getFunction()];
2648+
});
2649+
2650+
// Now try to constant fold the resolver for every callsite starting
2651+
// from higher priority callers. This guarantees that as soon as we
2652+
// find a callee whose priority is lower than the expected best match
2653+
// then there is no point in continuing further.
2654+
DenseMap<uint64_t, Function *> foldedResolverCache;
2655+
for (CallBase *CS : CallSites) {
2656+
uint64_t CallerPriority = FeaturePriorityMap[CS->getFunction()];
2657+
auto [It, Inserted] = foldedResolverCache.try_emplace(CallerPriority);
2658+
Function *&Callee = It->second;
2659+
if (Inserted)
2660+
Callee = foldResolverForCallSite(CS, CallerPriority, TTI);
2661+
if (Callee) {
2662+
if (!Callees.empty()) {
2663+
// If the priority of the candidate is greater or equal to
2664+
// the expected best match then it shall be picked. Otherwise
2665+
// there is a higher priority callee without a corresponding
2666+
// caller, in which case abort.
2667+
uint64_t CalleePriority = FeaturePriorityMap[Callee];
2668+
if (CalleePriority == FeaturePriorityMap[Callees.back()])
2669+
Callees.pop_back();
2670+
else if (CalleePriority < FeaturePriorityMap[Callees.back()])
2671+
break;
2672+
}
2673+
CS->setCalledOperand(Callee);
2674+
Changed = true;
2675+
} else {
2676+
// Oops, something went wrong. We couldn't fold. Abort.
2677+
break;
2678+
}
2679+
}
2680+
if (IF.use_empty() ||
2681+
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
2682+
NumIFuncsResolved++;
2683+
}
2684+
return Changed;
2685+
}
2686+
24652687
static bool
24662688
optimizeGlobalsInModule(Module &M, const DataLayout &DL,
24672689
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
@@ -2525,6 +2747,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
25252747
// Optimize IFuncs whose callees are statically known.
25262748
LocalChange |= OptimizeStaticIFuncs(M);
25272749

2750+
// Optimize IFuncs based on the target features of the caller.
2751+
LocalChange |= OptimizeNonTrivialIFuncs(M, GetTTI);
2752+
25282753
// Remove any IFuncs that are now dead.
25292754
LocalChange |= DeleteDeadIFuncs(M, NotDiscardableComdats);
25302755

0 commit comments

Comments
 (0)