@@ -7659,32 +7659,38 @@ buildIntrinsicArgTypes(const CallInst *CI, const Intrinsic::ID ID,
7659
7659
}
7660
7660
7661
7661
/// Calculates the costs of vectorized intrinsic (if possible) and vectorized
7662
- /// function (if possible) calls.
7662
+ /// function (if possible) calls. Returns {intrinsic cost, library call cost};
7663
+ /// an entry is invalid if that call cannot be vectorized/will be scalarized.
/// The library call cost is valid only when \p CI is not marked nobuiltin and
/// VFDatabase provides a vector variant for the requested fixed-width shape.
/// The intrinsic cost is invalidated when it exceeds the library call cost, or
/// a fixed limit of 10000 when no library call cost is available.
7663
7664
static std::pair<InstructionCost, InstructionCost>
7664
7665
getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
7665
7666
TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
7666
7667
ArrayRef<Type *> ArgTys) {
7667
- Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
7668
-
7669
- // Calculate the cost of the scalar and vector calls.
7670
- FastMathFlags FMF;
7671
- if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
7672
- FMF = FPCI->getFastMathFlags();
7673
- IntrinsicCostAttributes CostAttrs(ID, VecTy, ArgTys, FMF);
7674
- auto IntrinsicCost =
7675
- TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
7676
-
// Describe the vector variant to look up: the call's function type with a
// fixed element count equal to the number of elements in VecTy, and no global
// predicate.
7677
7668
auto Shape = VFShape::get(CI->getFunctionType(),
7678
7669
ElementCount::getFixed(VecTy->getNumElements()),
7679
7670
false /*HasGlobalPred*/);
7680
7671
Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
7681
- auto LibCost = IntrinsicCost ;
7672
+ auto LibCost = InstructionCost::getInvalid() ;
7682
7673
if (!CI->isNoBuiltin() && VecFunc) {
7683
7674
// Calculate the cost of the vector library call.
7684
7675
// It is reported as the second element of the returned pair.
7685
7676
LibCost =
7686
7677
TTI->getCallInstrCost(nullptr, VecTy, ArgTys, TTI::TCK_RecipThroughput);
7687
7678
}
7679
+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
7680
+
7681
+ // Calculate the cost of the vector intrinsic call.
7682
+ FastMathFlags FMF;
7683
+ if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
7684
+ FMF = FPCI->getFastMathFlags();
// Fallback threshold used in place of LibCost when no vector library call
// cost is available.
7685
+ const InstructionCost ScalarLimit = 10000;
// NOTE(review): the last argument is presumably the cost bound the intrinsic
// expansion is weighed against (ScalarCost?) - confirm against the
// IntrinsicCostAttributes constructor.
7686
+ IntrinsicCostAttributes CostAttrs(ID, VecTy, ArgTys, FMF, nullptr,
7687
+ LibCost.isValid() ? LibCost : ScalarLimit);
7688
+ auto IntrinsicCost =
7689
+ TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
// Mark the intrinsic cost invalid if it exceeds the library call cost, or the
// fallback limit when there is no valid library call cost.
7690
+ if ((LibCost.isValid() && IntrinsicCost > LibCost) ||
7691
+ (!LibCost.isValid() && IntrinsicCost > ScalarLimit))
7692
+ IntrinsicCost = InstructionCost::getInvalid();
7693
+
7688
7694
return {IntrinsicCost, LibCost};
7689
7695
}
7690
7696
@@ -8028,6 +8034,12 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
8028
8034
return TreeEntry::NeedToGather;
8029
8035
}
8030
8036
}
8037
+ SmallVector<Type *> ArgTys =
8038
+ buildIntrinsicArgTypes(CI, ID, VL.size(), 0, TTI);
8039
+ auto *VecTy = getWidenedType(S.getMainOp()->getType(), VL.size());
8040
+ auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
8041
+ if (!VecCallCosts.first.isValid() && !VecCallCosts.second.isValid())
8042
+ return TreeEntry::NeedToGather;
8031
8043
8032
8044
return TreeEntry::Vectorize;
8033
8045
}
0 commit comments