@@ -89,7 +89,7 @@ STATISTIC(NumAliasesRemoved, "Number of global aliases eliminated");
89
89
STATISTIC (NumCXXDtorsRemoved, " Number of global C++ destructors removed" );
90
90
STATISTIC (NumInternalFunc, " Number of internal functions" );
91
91
STATISTIC (NumColdCC, " Number of functions marked coldcc" );
92
- STATISTIC (NumIFuncsResolved, " Number of statically resolved IFuncs" );
92
+ STATISTIC (NumIFuncsResolved, " Number of resolved IFuncs" );
93
93
STATISTIC (NumIFuncsDeleted, " Number of IFuncs removed" );
94
94
95
95
static cl::opt<bool >
@@ -2462,6 +2462,228 @@ DeleteDeadIFuncs(Module &M,
2462
2462
return Changed;
2463
2463
}
2464
2464
2465
+ static Function *foldResolverForCallSite (CallBase *CS, uint64_t Priority,
2466
+ TargetTransformInfo &TTI) {
2467
+ // Look for the instruction which feeds the feature mask to the users.
2468
+ auto findRoot = [&TTI](Function *F) -> Instruction * {
2469
+ for (Instruction &I : F->getEntryBlock ())
2470
+ if (auto *Load = dyn_cast<LoadInst>(&I))
2471
+ if (Load->getPointerOperand () == TTI.getCPUFeatures (*F->getParent ()))
2472
+ return Load;
2473
+ return nullptr ;
2474
+ };
2475
+
2476
+ auto *IF = cast<GlobalIFunc>(CS->getCalledOperand ());
2477
+ Instruction *Root = findRoot (IF->getResolverFunction ());
2478
+ // There is no such instruction. Bail.
2479
+ if (!Root)
2480
+ return nullptr ;
2481
+
2482
+ // Create a constant mask to use as seed for the constant propagation.
2483
+ Constant *Seed = Constant::getIntegerValue (
2484
+ Root->getType (), APInt (Root->getType ()->getIntegerBitWidth (), Priority));
2485
+
2486
+ auto DL = CS->getModule ()->getDataLayout ();
2487
+
2488
+ // Recursively propagate on single use chains.
2489
+ std::function<Constant *(Instruction *, Instruction *, Constant *,
2490
+ BasicBlock *)>
2491
+ constFoldInst = [&](Instruction *I, Instruction *Use, Constant *C,
2492
+ BasicBlock *Pred) -> Constant * {
2493
+ // Base case.
2494
+ if (auto *Ret = dyn_cast<ReturnInst>(I))
2495
+ if (Ret->getReturnValue () == Use)
2496
+ return C;
2497
+
2498
+ // Minimal set of instruction types to handle.
2499
+ if (auto *BinOp = dyn_cast<BinaryOperator>(I)) {
2500
+ bool Swap = BinOp->getOperand (1 ) == Use;
2501
+ if (auto *Other = dyn_cast<Constant>(BinOp->getOperand (Swap ? 0 : 1 )))
2502
+ C = Swap ? ConstantFoldBinaryInstruction (BinOp->getOpcode (), Other, C)
2503
+ : ConstantFoldBinaryInstruction (BinOp->getOpcode (), C, Other);
2504
+ } else if (auto *Cmp = dyn_cast<CmpInst>(I)) {
2505
+ bool Swap = Cmp->getOperand (1 ) == Use;
2506
+ if (auto *Other = dyn_cast<Constant>(Cmp->getOperand (Swap ? 0 : 1 )))
2507
+ C = Swap ? ConstantFoldCompareInstOperands (Cmp->getPredicate (), Other,
2508
+ C, DL)
2509
+ : ConstantFoldCompareInstOperands (Cmp->getPredicate (), C,
2510
+ Other, DL);
2511
+ } else if (auto *Sel = dyn_cast<SelectInst>(I)) {
2512
+ if (Sel->getCondition () == Use)
2513
+ C = dyn_cast<Constant>(C->isZeroValue () ? Sel->getFalseValue ()
2514
+ : Sel->getTrueValue ());
2515
+ } else if (auto *Phi = dyn_cast<PHINode>(I)) {
2516
+ if (Pred)
2517
+ C = dyn_cast<Constant>(Phi->getIncomingValueForBlock (Pred));
2518
+ } else if (auto *Br = dyn_cast<BranchInst>(I)) {
2519
+ if (Br->getCondition () == Use) {
2520
+ BasicBlock *BB = Br->getSuccessor (C->isZeroValue ());
2521
+ return constFoldInst (&BB->front (), Root, Seed, Br->getParent ());
2522
+ }
2523
+ } else {
2524
+ // Don't know how to handle. Bail.
2525
+ return nullptr ;
2526
+ }
2527
+
2528
+ // Folding succeeded. Continue.
2529
+ if (C && I->hasOneUse ())
2530
+ if (auto *UI = dyn_cast<Instruction>(I->user_back ()))
2531
+ return constFoldInst (UI, I, C, nullptr );
2532
+
2533
+ return nullptr ;
2534
+ };
2535
+
2536
+ // Collect all users in the entry block ordered by proximity. The rest of
2537
+ // them can be discovered later. Unfortunately we cannot simply traverse
2538
+ // the Root's 'users()' as their order is not the same as execution order.
2539
+ unsigned NUsersLeft = std::distance (Root->user_begin (), Root->user_end ());
2540
+ SmallVector<Instruction *> Users;
2541
+ for (Instruction &I : *Root->getParent ()) {
2542
+ if (any_of (I.operands (), [Root](auto &Op) { return Op == Root; })) {
2543
+ Users.push_back (&I);
2544
+ if (--NUsersLeft == 0 )
2545
+ break ;
2546
+ }
2547
+ }
2548
+
2549
+ // Return as soon as we find a foldable user. It has the highest priority.
2550
+ for (Instruction *I : Users) {
2551
+ Constant *C = constFoldInst (I, Root, Seed, nullptr );
2552
+ if (C)
2553
+ return cast<Function>(C);
2554
+ }
2555
+
2556
+ return nullptr ;
2557
+ }
2558
+
2559
+ // Bypass the IFunc Resolver of MultiVersioned functions when possible. To
2560
+ // deduce whether the optimization is legal we need to compare the target
2561
+ // features between caller and callee versions. The criteria for bypassing
2562
+ // the resolver are the following:
2563
+ //
2564
+ // * If the callee's feature set is a subset of the caller's feature set,
2565
+ // then the callee is a candidate for direct call.
2566
+ //
2567
+ // * Among such candidates the one of highest priority is the best match
2568
+ // and it shall be picked, unless there is a version of the callee with
2569
+ // higher priority than the best match which cannot be picked because
2570
+ // there is no corresponding caller for whom it would have been the best
2571
+ // match.
2572
+ //
2573
+ static bool OptimizeNonTrivialIFuncs (
2574
+ Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
2575
+ bool Changed = false ;
2576
+
2577
+ std::function<void (Value *, SmallVectorImpl<Function *> &)> visitValue =
2578
+ [&](Value *V, SmallVectorImpl<Function *> &FuncVersions) {
2579
+ if (auto *Func = dyn_cast<Function>(V)) {
2580
+ FuncVersions.push_back (Func);
2581
+ } else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2582
+ visitValue (Sel->getTrueValue (), FuncVersions);
2583
+ visitValue (Sel->getFalseValue (), FuncVersions);
2584
+ } else if (auto *Phi = dyn_cast<PHINode>(V))
2585
+ for (unsigned I = 0 , E = Phi->getNumIncomingValues (); I != E; ++I)
2586
+ visitValue (Phi->getIncomingValue (I), FuncVersions);
2587
+ };
2588
+
2589
+ // Cache containing the mask constructed from a function's target features.
2590
+ DenseMap<Function *, uint64_t > FeaturePriorityMap;
2591
+
2592
+ for (GlobalIFunc &IF : M.ifuncs ()) {
2593
+ if (IF.isInterposable ())
2594
+ continue ;
2595
+
2596
+ Function *Resolver = IF.getResolverFunction ();
2597
+ if (!Resolver)
2598
+ continue ;
2599
+
2600
+ if (Resolver->isInterposable ())
2601
+ continue ;
2602
+
2603
+ TargetTransformInfo &TTI = GetTTI (*Resolver);
2604
+ if (!TTI.hasFMV ())
2605
+ return false ;
2606
+
2607
+ // Discover the callee versions.
2608
+ SmallVector<Function *> Callees;
2609
+ for (BasicBlock &BB : *Resolver)
2610
+ if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator ()))
2611
+ visitValue (Ret->getReturnValue (), Callees);
2612
+
2613
+ if (Callees.empty ())
2614
+ continue ;
2615
+
2616
+ // Cache the feature mask for each callee.
2617
+ for (Function *Callee : Callees) {
2618
+ auto [It, Inserted] = FeaturePriorityMap.try_emplace (Callee);
2619
+ if (Inserted)
2620
+ It->second = TTI.getFMVPriority (*Callee);
2621
+ }
2622
+
2623
+ // Sort the callee versions in increasing feature priority order.
2624
+ // Every time we find a caller that matches the highest priority
2625
+ // callee we pop_back() one from this ordered list.
2626
+ llvm::stable_sort (Callees, [&](auto *LHS, auto *RHS) {
2627
+ return FeaturePriorityMap[LHS] < FeaturePriorityMap[RHS];
2628
+ });
2629
+
2630
+ // Find the callsites and cache the feature mask for each caller.
2631
+ SmallVector<CallBase *> CallSites;
2632
+ for (User *U : IF.users ()) {
2633
+ if (auto *CB = dyn_cast<CallBase>(U)) {
2634
+ if (CB->getCalledOperand () == &IF) {
2635
+ Function *Caller = CB->getFunction ();
2636
+ auto [It, Inserted] = FeaturePriorityMap.try_emplace (Caller);
2637
+ if (Inserted)
2638
+ It->second = TTI.getFMVPriority (*Caller);
2639
+ CallSites.push_back (CB);
2640
+ }
2641
+ }
2642
+ }
2643
+
2644
+ // Sort the callsites in decreasing feature priority order.
2645
+ llvm::stable_sort (CallSites, [&](auto *LHS, auto *RHS) {
2646
+ return FeaturePriorityMap[LHS->getFunction ()] >
2647
+ FeaturePriorityMap[RHS->getFunction ()];
2648
+ });
2649
+
2650
+ // Now try to constant fold the resolver for every callsite starting
2651
+ // from higher priority callers. This guarantees that as soon as we
2652
+ // find a callee whose priority is lower than the expected best match
2653
+ // then there is no point in continuing further.
2654
+ DenseMap<uint64_t , Function *> foldedResolverCache;
2655
+ for (CallBase *CS : CallSites) {
2656
+ uint64_t CallerPriority = FeaturePriorityMap[CS->getFunction ()];
2657
+ auto [It, Inserted] = foldedResolverCache.try_emplace (CallerPriority);
2658
+ Function *&Callee = It->second ;
2659
+ if (Inserted)
2660
+ Callee = foldResolverForCallSite (CS, CallerPriority, TTI);
2661
+ if (Callee) {
2662
+ if (!Callees.empty ()) {
2663
+ // If the priority of the candidate is greater or equal to
2664
+ // the expected best match then it shall be picked. Otherwise
2665
+ // there is a higher priority callee without a corresponding
2666
+ // caller, in which case abort.
2667
+ uint64_t CalleePriority = FeaturePriorityMap[Callee];
2668
+ if (CalleePriority == FeaturePriorityMap[Callees.back ()])
2669
+ Callees.pop_back ();
2670
+ else if (CalleePriority < FeaturePriorityMap[Callees.back ()])
2671
+ break ;
2672
+ }
2673
+ CS->setCalledOperand (Callee);
2674
+ Changed = true ;
2675
+ } else {
2676
+ // Oops, something went wrong. We couldn't fold. Abort.
2677
+ break ;
2678
+ }
2679
+ }
2680
+ if (IF.use_empty () ||
2681
+ all_of (IF.users (), [](User *U) { return isa<GlobalAlias>(U); }))
2682
+ NumIFuncsResolved++;
2683
+ }
2684
+ return Changed;
2685
+ }
2686
+
2465
2687
static bool
2466
2688
optimizeGlobalsInModule (Module &M, const DataLayout &DL,
2467
2689
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
@@ -2525,6 +2747,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
2525
2747
// Optimize IFuncs whose callee's are statically known.
2526
2748
LocalChange |= OptimizeStaticIFuncs (M);
2527
2749
2750
+ // Optimize IFuncs based on the target features of the caller.
2751
+ LocalChange |= OptimizeNonTrivialIFuncs (M, GetTTI);
2752
+
2528
2753
// Remove any IFuncs that are now dead.
2529
2754
LocalChange |= DeleteDeadIFuncs (M, NotDiscardableComdats);
2530
2755
0 commit comments