Skip to content

[TLI] ReplaceWithVecLib: drop Instruction support #94365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 71 additions & 103 deletions llvm/lib/CodeGen/ReplaceWithVeclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
// Replaces calls to LLVM Intrinsics with matching calls to functions from a
// vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
//
//===----------------------------------------------------------------------===//

Expand All @@ -25,6 +24,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/VFABIDemangler.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
Expand Down Expand Up @@ -70,84 +70,68 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
return TLIFunc;
}

/// Replace the instruction \p I with a call to the corresponding function from
/// the vector library (\p TLIVecFunc).
static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
/// Replace the intrinsic call \p II to \p TLIVecFunc, which is the
/// corresponding function from the vector library.
static void replaceWithTLIFunction(IntrinsicInst *II, VFInfo &Info,
Function *TLIVecFunc) {
IRBuilder<> IRBuilder(&I);
auto *CI = dyn_cast<CallInst>(&I);
SmallVector<Value *> Args(CI ? CI->args() : I.operands());
IRBuilder<> IRBuilder(II);
SmallVector<Value *> Args(II->args());
if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
auto *MaskTy =
VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
VectorType::get(Type::getInt1Ty(II->getContext()), Info.Shape.VF);
Args.insert(Args.begin() + OptMaskpos.value(),
Constant::getAllOnesValue(MaskTy));
}

// If it is a call instruction, preserve the operand bundles.
// Preserve the operand bundles.
SmallVector<OperandBundleDef, 1> OpBundles;
if (CI)
CI->getOperandBundlesAsDefs(OpBundles);
II->getOperandBundlesAsDefs(OpBundles);

auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
I.replaceAllUsesWith(Replacement);
II->replaceAllUsesWith(Replacement);
// Preserve fast math flags for FP math.
if (isa<FPMathOperator>(Replacement))
Replacement->copyFastMathFlags(&I);
Replacement->copyFastMathFlags(II);
}

/// Returns true when successfully replaced \p I with a suitable function taking
/// vector arguments, based on available mappings in the \p TLI. Currently only
/// works when \p I is a call to vectorized intrinsic or the frem instruction.
/// Returns true when successfully replaced \p II, which is a call to a
/// vectorized intrinsic, with a suitable function taking vector arguments,
/// based on available mappings in the \p TLI.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Instruction &I) {
IntrinsicInst *II) {
assert(II != nullptr && "Intrinsic cannot be null");
// At the moment VFABI assumes the return type is always widened unless it is
// a void type.
auto *VTy = dyn_cast<VectorType>(I.getType());
auto *VTy = dyn_cast<VectorType>(II->getType());
ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));

// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
// and checks that all vector operands match the previously found EC.
// Compute the argument types of the corresponding scalar call and check that
// all vector operands match the previously found EC.
SmallVector<Type *, 8> ScalarArgTypes;
std::string ScalarName;
Function *FuncToReplace = nullptr;
auto *CI = dyn_cast<CallInst>(&I);
if (CI) {
FuncToReplace = CI->getCalledFunction();
Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
for (auto Arg : enumerate(CI->args())) {
auto *ArgTy = Arg.value()->getType();
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
ScalarArgTypes.push_back(VectorArgTy->getElementType());
// When return type is void, set EC to the first vector argument, and
// disallow vector arguments with different ECs.
if (EC.isZero())
EC = VectorArgTy->getElementCount();
else if (EC != VectorArgTy->getElementCount())
return false;
} else
// Exit when it is supposed to be a vector argument but it isn't.
Intrinsic::ID IID = II->getIntrinsicID();
for (auto Arg : enumerate(II->args())) {
auto *ArgTy = Arg.value()->getType();
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
ScalarArgTypes.push_back(VectorArgTy->getElementType());
// When return type is void, set EC to the first vector argument, and
// disallow vector arguments with different ECs.
if (EC.isZero())
EC = VectorArgTy->getElementCount();
else if (EC != VectorArgTy->getElementCount())
return false;
}
// Try to reconstruct the name for the scalar version of the instruction,
// using scalar argument types.
ScalarName = Intrinsic::isOverloaded(IID)
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
: Intrinsic::getName(IID).str();
} else {
assert(VTy && "Return type must be a vector");
auto *ScalarTy = VTy->getScalarType();
LibFunc Func;
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
} else
// Exit when it is supposed to be a vector argument but it isn't.
return false;
ScalarName = TLI.getName(Func);
ScalarArgTypes = {ScalarTy, ScalarTy};
}

// Try to reconstruct the name for the scalar version of the instruction,
// using scalar argument types.
std::string ScalarName =
Intrinsic::isOverloaded(IID)
? Intrinsic::getName(IID, ScalarArgTypes, II->getModule())
: Intrinsic::getName(IID).str();

// Try to find the mapping for the scalar version of this intrinsic and the
// exact vector width of the call operands in the TargetLibraryInfo. First,
// check with a non-masked variant, and if that fails try with a masked one.
Expand All @@ -162,7 +146,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,

// Replace the call to the intrinsic with a call to the vector library
// function.
Type *ScalarRetTy = I.getType()->getScalarType();
Type *ScalarRetTy = II->getType()->getScalarType();
FunctionType *ScalarFTy =
FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
const std::string MangledName = VD->getVectorFunctionABIVariantString();
Expand All @@ -174,68 +158,52 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
// specification when being created, this is why we need to add extra check to
// make sure that the operands of the vector function obtained via VFABI match
// the operands of the original vector instruction.
if (CI) {
for (auto &VFParam : OptInfo->Shape.Parameters) {
if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
continue;
for (auto &VFParam : OptInfo->Shape.Parameters) {
if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
continue;

// tryDemangleForVFABI must return valid ParamPos, otherwise it could be
// a bug in the VFABI parser.
assert(VFParam.ParamPos < CI->arg_size() &&
"ParamPos has invalid range.");
Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
<< ". Wrong type at index " << VFParam.ParamPos
<< ": " << *OrigTy << "\n");
return false;
}
// tryDemangleForVFABI must return valid ParamPos, otherwise it could be
// a bug in the VFABI parser.
assert(VFParam.ParamPos < II->arg_size() && "ParamPos has invalid range");
Type *OrigTy = II->getArgOperand(VFParam.ParamPos)->getType();
if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
<< ". Wrong type at index " << VFParam.ParamPos << ": "
<< *OrigTy << "\n");
return false;
}
}

FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
if (!VectorFTy)
return false;

Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);

replaceWithTLIFunction(I, *OptInfo, TLIFunc);
Function *TLIFunc =
getTLIFunction(II->getModule(), VectorFTy, VD->getVectorFnName(),
II->getCalledFunction());
replaceWithTLIFunction(II, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
++NumCallsReplaced;
return true;
}

/// Supported instruction \p I must be a vectorized frem or a call to an
/// intrinsic that returns either void or a vector.
static bool isSupportedInstruction(Instruction *I) {
Type *Ty = I->getType();
if (auto *CI = dyn_cast<CallInst>(I))
return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
CI->getCalledFunction()->getIntrinsicID() !=
Intrinsic::not_intrinsic;
if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
return true;
return false;
}

static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
bool Changed = false;
SmallVector<Instruction *> ReplacedCalls;
for (auto &I : instructions(F)) {
if (!isSupportedInstruction(&I))
continue;
if (replaceWithCallToVeclib(TLI, I)) {
ReplacedCalls.push_back(&I);
Changed = true;
// Process only intrinsic calls that return void or a vector.
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
if (!II->getType()->isVectorTy() && !II->getType()->isVoidTy())
continue;

if (replaceWithCallToVeclib(TLI, II))
ReplacedCalls.push_back(&I);
}
}
// Erase the calls to the intrinsics that have been replaced
// with calls to the vector library.
for (auto *CI : ReplacedCalls)
CI->eraseFromParent();
return Changed;
// Erase any intrinsic calls that were replaced with vector library calls.
for (auto *I : ReplacedCalls)
I->eraseFromParent();
return !ReplacedCalls.empty();
}

////////////////////////////////////////////////////////////////////////////////
Expand All @@ -246,7 +214,7 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto Changed = runImpl(TLI, F);
if (Changed) {
LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
LLVM_DEBUG(dbgs() << "Intrinsic calls replaced with vector libraries: "
<< NumCallsReplaced << "\n");

PreservedAnalyses PA;
Expand Down
42 changes: 1 addition & 41 deletions llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ declare <vscale x 2 x double> @llvm.cos.nxv2f64(<vscale x 2 x double>)
declare <vscale x 4 x float> @llvm.cos.nxv4f32(<vscale x 4 x float>)

;.
; CHECK: @llvm.compiler.used = appending global [40 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vtanq_f64, ptr @armpl_vtanq_f32, ptr @armpl_svtan_f64_x, ptr @armpl_svtan_f32_x, ptr @armpl_vfmodq_f64, ptr @armpl_vfmodq_f32, ptr @armpl_svfmod_f64_x, ptr @armpl_svfmod_f32_x], section "llvm.metadata"
; CHECK: @llvm.compiler.used = appending global [36 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vtanq_f64, ptr @armpl_vtanq_f32, ptr @armpl_svtan_f64_x, ptr @armpl_svtan_f32_x], section "llvm.metadata"
;.
define <2 x double> @llvm_cos_f64(<2 x double> %in) {
; CHECK-LABEL: define <2 x double> @llvm_cos_f64
Expand Down Expand Up @@ -469,46 +469,6 @@ define <vscale x 4 x float> @llvm_tan_vscale_f32(<vscale x 4 x float> %in) #0 {
ret <vscale x 4 x float> %1
}

define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: define <2 x double> @frem_f64
; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in
ret <2 x double> %1
}

define <4 x float> @frem_f32(<4 x float> %in) {
; CHECK-LABEL: define <4 x float> @frem_f32
; CHECK-SAME: (<4 x float> [[IN:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @armpl_vfmodq_f32(<4 x float> [[IN]], <4 x float> [[IN]])
; CHECK-NEXT: ret <4 x float> [[TMP1]]
;
%1= frem <4 x float> %in, %in
ret <4 x float> %1
}

define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in) #0 {
; CHECK-LABEL: define <vscale x 2 x double> @frem_vscale_f64
; CHECK-SAME: (<vscale x 2 x double> [[IN:%.*]]) #[[ATTR1]] {
; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[IN]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 2 x double> [[TMP1]]
;
%1= frem <vscale x 2 x double> %in, %in
ret <vscale x 2 x double> %1
}

define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in) #0 {
; CHECK-LABEL: define <vscale x 4 x float> @frem_vscale_f32
; CHECK-SAME: (<vscale x 4 x float> [[IN:%.*]]) #[[ATTR1]] {
; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[IN]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 4 x float> [[TMP1]]
;
%1= frem <vscale x 4 x float> %in, %in
ret <vscale x 4 x float> %1
}

attributes #0 = { "target-features"="+sve" }
;.
; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
Expand Down
20 changes: 1 addition & 19 deletions llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
target triple = "aarch64-unknown-linux-gnu"

;.
; CHECK: @llvm.compiler.used = appending global [20 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf, ptr @_ZGVsMxv_tan, ptr @_ZGVsMxv_tanf, ptr @_ZGVsMxvv_fmod, ptr @_ZGVsMxvv_fmodf], section "llvm.metadata"
; CHECK: @llvm.compiler.used = appending global [18 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf, ptr @_ZGVsMxv_tan, ptr @_ZGVsMxv_tanf], section "llvm.metadata"
;.
define <vscale x 2 x double> @llvm_ceil_vscale_f64(<vscale x 2 x double> %in) {
; CHECK-LABEL: @llvm_ceil_vscale_f64(
Expand Down Expand Up @@ -403,24 +403,6 @@ define <vscale x 4 x float> @llvm_trunc_vscale_f32(<vscale x 4 x float> %in) {
ret <vscale x 4 x float> %1
}

define <vscale x 2 x double> @frem_f64(<vscale x 2 x double> %in) {
; CHECK-LABEL: @frem_f64(
; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN:%.*]], <vscale x 2 x double> [[IN]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 2 x double> [[TMP1]]
;
%1= frem <vscale x 2 x double> %in, %in
ret <vscale x 2 x double> %1
}

define <vscale x 4 x float> @frem_f32(<vscale x 4 x float> %in) {
; CHECK-LABEL: @frem_f32(
; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN:%.*]], <vscale x 4 x float> [[IN]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 4 x float> [[TMP1]]
;
%1= frem <vscale x 4 x float> %in, %in
ret <vscale x 4 x float> %1
}

declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
declare <vscale x 2 x double> @llvm.copysign.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
Expand Down
Loading
Loading