@@ -43,24 +43,24 @@ STATISTIC(NumFuncUsedAdded,
4343 " Number of functions added to `llvm.compiler.used`" );
4444
4545// / Returns a vector Function that it adds to the Module \p M. When an \p
46- // / OptOldFunc is given, it copies its attributes to the newly created Function.
46+ // / ScalarFunc is not null, it copies its attributes to the newly created
47+ // / Function.
4748Function *getTLIFunction (Module *M, FunctionType *VectorFTy,
48- std::optional<Function *> OptOldFunc,
49- const StringRef TLIName) {
49+ Function *ScalarFunc, const StringRef TLIName) {
5050 Function *TLIFunc = M->getFunction (TLIName);
5151 if (!TLIFunc) {
5252 TLIFunc =
5353 Function::Create (VectorFTy, Function::ExternalLinkage, TLIName, *M);
54- if (OptOldFunc )
55- TLIFunc->copyAttributesFrom (*OptOldFunc );
54+ if (ScalarFunc )
55+ TLIFunc->copyAttributesFrom (ScalarFunc );
5656
5757 LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Added vector library function `"
5858 << TLIName << " ` of type `" << *(TLIFunc->getType ())
5959 << " ` to module.\n " );
6060
6161 ++NumTLIFuncDeclAdded;
6262 // Add the freshly created function to llvm.compiler.used, similar to as it
63- // is done in InjectTLIMappings
63+ // is done in InjectTLIMappings.
6464 appendToCompilerUsed (*M, {TLIFunc});
6565 LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Adding `" << TLIName
6666 << " ` to `@llvm.compiler.used`.\n " );
@@ -72,11 +72,11 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
7272// / Replace the call to the vector intrinsic ( \p FuncToReplace ) with a call to
7373// / the corresponding function from the vector library ( \p TLIFunc ).
7474static void replaceWithTLIFunction (CallInst &CI, VFInfo &Info,
75- Function *TLIFunc, FunctionType *VecFTy ) {
75+ Function *TLIVecFunc ) {
7676 IRBuilder<> IRBuilder (&CI);
7777 SmallVector<Value *> Args (CI.args ());
7878 if (auto OptMaskpos = Info.getParamIndexForOptionalMask ()) {
79- if (Args.size () == VecFTy ->getNumParams ())
79+ if (Args.size () == TLIVecFunc-> getFunctionType () ->getNumParams ())
8080 static_assert (true && " mask was already in place" );
8181
8282 auto *MaskTy =
@@ -88,9 +88,7 @@ static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
8888 // Preserve the operand bundles.
8989 SmallVector<OperandBundleDef, 1 > OpBundles;
9090 CI.getOperandBundlesAsDefs (OpBundles);
91- CallInst *Replacement = IRBuilder.CreateCall (TLIFunc, Args, OpBundles);
92- assert (VecFTy == TLIFunc->getFunctionType () &&
93- " Expecting function types to be identical" );
91+ CallInst *Replacement = IRBuilder.CreateCall (TLIVecFunc, Args, OpBundles);
9492 CI.replaceAllUsesWith (Replacement);
9593 // Preserve fast math flags for FP math.
9694 if (isa<FPMathOperator>(Replacement))
@@ -102,10 +100,10 @@ static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
102100static std::optional<const VecDesc *> getVecDesc (const TargetLibraryInfo &TLI,
103101 const StringRef &ScalarName,
104102 const ElementCount &VF) {
105- if (auto *VDMasked = TLI.getVectorMappingInfo (ScalarName, VF, true ))
106- return VDMasked;
107103 if (auto *VDNoMask = TLI.getVectorMappingInfo (ScalarName, VF, false ))
108104 return VDNoMask;
105+ if (auto *VDMasked = TLI.getVectorMappingInfo (ScalarName, VF, true ))
106+ return VDMasked;
109107 return std::nullopt ;
110108}
111109
@@ -117,20 +115,20 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
117115 return false ;
118116
119117 auto IntrinsicID = CI.getCalledFunction ()->getIntrinsicID ();
120- // Replacement is only performed for intrinsic functions
118+ // Replacement is only performed for intrinsic functions.
121119 if (IntrinsicID == Intrinsic::not_intrinsic)
122120 return false ;
123121
124122 // Convert vector arguments to scalar type and check that all vector operands
125123 // have identical vector width.
126124 ElementCount VF = ElementCount::getFixed (0 );
127- SmallVector<Type *> ScalarTypes ;
125+ SmallVector<Type *> ScalarArgTypes ;
128126 for (auto Arg : enumerate(CI.args ())) {
129127 auto *ArgTy = Arg.value ()->getType ();
130128 if (isVectorIntrinsicWithScalarOpAtArg (IntrinsicID, Arg.index ())) {
131- ScalarTypes .push_back (ArgTy);
129+ ScalarArgTypes .push_back (ArgTy);
132130 } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
133- ScalarTypes .push_back (ArgTy->getScalarType ());
131+ ScalarArgTypes .push_back (ArgTy->getScalarType ());
134132 // Disallow vector arguments with different VFs. When processing the first
135133 // vector argument, store it's VF, and for the rest ensure that they match
136134 // it.
@@ -139,15 +137,15 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
139137 else if (VF != VectorArgTy->getElementCount ())
140138 return false ;
141139 } else
142- // enters when it is supposed to be a vector argument but it isn't.
140+ // Exit when it is supposed to be a vector argument but it isn't.
143141 return false ;
144142 }
145143
146144 // Try to reconstruct the name for the scalar version of this intrinsic using
147145 // the intrinsic ID and the argument types converted to scalar above.
148146 std::string ScalarName =
149147 (Intrinsic::isOverloaded (IntrinsicID)
150- ? Intrinsic::getName (IntrinsicID, ScalarTypes , CI.getModule ())
148+ ? Intrinsic::getName (IntrinsicID, ScalarArgTypes , CI.getModule ())
151149 : Intrinsic::getName (IntrinsicID).str ());
152150
153151 // The TargetLibraryInfo does not contain a vectorized version of the scalar
@@ -169,7 +167,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
169167 // Replace the call to the intrinsic with a call to the vector library
170168 // function.
171169 Type *ScalarRetTy = CI.getType ()->getScalarType ();
172- FunctionType *ScalarFTy = FunctionType::get (ScalarRetTy, ScalarTypes, false );
170+ FunctionType *ScalarFTy =
171+ FunctionType::get (ScalarRetTy, ScalarArgTypes, /* isVarArg*/ false );
173172 const std::string MangledName = VD->getVectorFunctionABIVariantString ();
174173 auto OptInfo = VFABI::tryDemangleForVFABI (MangledName, ScalarFTy);
175174 if (!OptInfo)
@@ -182,7 +181,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
182181 Function *FuncToReplace = CI.getCalledFunction ();
183182 Function *TLIFunc = getTLIFunction (CI.getModule (), VectorFTy, FuncToReplace,
184183 VD->getVectorFnName ());
185- replaceWithTLIFunction (CI, *OptInfo, TLIFunc, VectorFTy );
184+ replaceWithTLIFunction (CI, *OptInfo, TLIFunc);
186185
187186 LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
188187 << FuncToReplace->getName () << " ` with call to `"
0 commit comments