[LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. #83859
Conversation
This removes, at least when a vector library is available, a failure case for scalable vectors, which cannot be handled by the generic unroll fallback. Doing so means we can confidently cost vector FREM instructions without making an assumption that later passes will transform the IR before it gets to the code generator.

NOTE: Currently only FREM has been implemented, but the same mechanism can be used for the other libm-related ISD nodes.

@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-selectiondag

Author: Paul Walker (paulwalker-arm)

Full diff: https://github.com/llvm/llvm-project/pull/83859.diff

3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6074498d9144ff..ebd6f62a63ac4d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -28,6 +28,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -147,6 +149,14 @@ class VectorLegalizer {
void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+ bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+ SmallVectorImpl<SDValue> &Results);
+ bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32,
+ RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80,
+ RTLIB::Libcall Call_F128,
+ RTLIB::Libcall Call_PPCF128,
+ SmallVectorImpl<SDValue> &Results);
+
void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
/// Implements vector promotion.
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VP_MERGE:
Results.push_back(ExpandVP_MERGE(Node));
return;
+ case ISD::FREM:
+ if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64,
+ RTLIB::REM_F80, RTLIB::REM_F128,
+ RTLIB::REM_PPCF128, Results))
+ return;
+
+ break;
}
SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1842,6 +1859,116 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
Results.push_back(Result);
}
+// Try to expand libm nodes into a call to a vector math library function.
+// Callers provide the LibFunc equivalent of the passed-in Node, which is used
+// to lookup mappings within TargetLibraryInfo. Only simple mappings are
+// considered, whereby only matching vector operands are allowed and masked
+// functions are passed an all-true vector (i.e. Node cannot be predicated).
+bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+ SmallVectorImpl<SDValue> &Results) {
+  // Chain must be propagated, but currently strict fp operations are
+  // down-converted to their non-strict counterpart.
+ assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");
+
+ const char *LCName = TLI.getLibcallName(LC);
+ if (!LCName)
+ return false;
+ LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");
+
+ EVT VT = Node->getValueType(0);
+ ElementCount VL = VT.getVectorElementCount();
+
+ // Lookup a vector function equivalent to the specified libcall. Prefer
+ // unmasked variants but we will generate a mask if need be.
+ const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
+ const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
+ if (!VD)
+ VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
+ if (!VD)
+ return false;
+
+ LLVMContext *Ctx = DAG.getContext();
+ Type *Ty = VT.getTypeForEVT(*Ctx);
+ Type *ScalarTy = Ty->getScalarType();
+
+ // Construct a scalar function type based on Node's operands.
+ SmallVector<Type *, 8> ArgTys;
+ for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
+ assert(Node->getOperand(i).getValueType() == VT &&
+ "Expected matching vector types!");
+ ArgTys.push_back(ScalarTy);
+ }
+ FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);
+
+ // Generate call information for the vector function.
+ const std::string MangledName = VD->getVectorFunctionABIVariantString();
+ auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+ if (!OptVFInfo)
+ return false;
+
+ LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
+ << "\n");
+
+  // Sanity check just in case OptVFInfo has unexpected parameters.
+ if (OptVFInfo->Shape.Parameters.size() !=
+ Node->getNumOperands() + VD->isMasked())
+ return false;
+
+ // Collect vector call operands.
+
+ SDLoc DL(Node);
+ TargetLowering::ArgListTy Args;
+ TargetLowering::ArgListEntry Entry;
+ Entry.IsSExt = false;
+ Entry.IsZExt = false;
+
+ unsigned OpNum = 0;
+ for (auto &VFParam : OptVFInfo->Shape.Parameters) {
+ if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
+ EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
+ Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
+ Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
+ Args.push_back(Entry);
+ continue;
+ }
+
+ // Only vector operands are supported.
+ if (VFParam.ParamKind != VFParamKind::Vector)
+ return false;
+
+ Entry.Node = Node->getOperand(OpNum++);
+ Entry.Ty = Ty;
+ Args.push_back(Entry);
+ }
+
+ // Emit a call to the vector function.
+ SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
+ TLI.getPointerTy(DAG.getDataLayout()));
+ TargetLowering::CallLoweringInfo CLI(DAG);
+ CLI.setDebugLoc(DL)
+ .setChain(DAG.getEntryNode())
+ .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));
+
+ std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+ Results.push_back(CallResult.first);
+ return true;
+}
+
+/// Try to expand the node to a vector libcall based on the result type.
+bool VectorLegalizer::tryExpandVecMathCall(
+ SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
+ RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
+ RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) {
+ RTLIB::Libcall LC = RTLIB::getFPLibCall(
+ Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64,
+ Call_F80, Call_F128, Call_PPCF128);
+
+ if (LC == RTLIB::UNKNOWN_LIBCALL)
+ return false;
+
+ return tryExpandVecMathCall(Node, LC, Results);
+}
+
void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
SmallVectorImpl<SDValue> &Results) {
EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index cf068ece8d4cab..8832b51333d910 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -205,6 +205,10 @@ static cl::opt<bool> MISchedPostRA(
static cl::opt<bool> EarlyLiveIntervals("early-live-intervals", cl::Hidden,
cl::desc("Run live interval analysis earlier in the pipeline"));
+static cl::opt<bool> DisableReplaceWithVecLib(
+ "disable-replace-with-vec-lib", cl::Hidden,
+ cl::desc("Disable replace with vector math call pass"));
+
/// Option names for limiting the codegen pipeline.
/// Those are used in error reporting and we didn't want
/// to duplicate their names all over the place.
@@ -856,7 +860,7 @@ void TargetPassConfig::addIRPasses() {
if (getOptLevel() != CodeGenOptLevel::None && !DisableConstantHoisting)
addPass(createConstantHoistingPass());
- if (getOptLevel() != CodeGenOptLevel::None)
+ if (getOptLevel() != CodeGenOptLevel::None && !DisableReplaceWithVecLib)
addPass(createReplaceWithVeclibLegacyPass());
if (getOptLevel() != CodeGenOptLevel::None && !DisablePartialLibcallInlining)
diff --git a/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
new file mode 100644
index 00000000000000..67c056c780cc80
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s
+; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 {
+; ARMPL-LABEL: frem_v2f64:
+; ARMPL: // %bb.0:
+; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT: .cfi_def_cfa_offset 16
+; ARMPL-NEXT: .cfi_offset w30, -16
+; ARMPL-NEXT: mov v0.16b, v1.16b
+; ARMPL-NEXT: mov v1.16b, v2.16b
+; ARMPL-NEXT: bl armpl_vfmodq_f64
+; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT: ret
+;
+; SLEEF-LABEL: frem_v2f64:
+; SLEEF: // %bb.0:
+; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT: .cfi_def_cfa_offset 16
+; SLEEF-NEXT: .cfi_offset w30, -16
+; SLEEF-NEXT: mov v0.16b, v1.16b
+; SLEEF-NEXT: mov v1.16b, v2.16b
+; SLEEF-NEXT: bl _ZGVnN2vv_fmod
+; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT: ret
+ %res = frem <2 x double> %a, %b
+ ret <2 x double> %res
+}
+
+define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 {
+; ARMPL-LABEL: frem_strict_v4f32:
+; ARMPL: // %bb.0:
+; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT: .cfi_def_cfa_offset 16
+; ARMPL-NEXT: .cfi_offset w30, -16
+; ARMPL-NEXT: mov v0.16b, v1.16b
+; ARMPL-NEXT: mov v1.16b, v2.16b
+; ARMPL-NEXT: bl armpl_vfmodq_f32
+; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT: ret
+;
+; SLEEF-LABEL: frem_strict_v4f32:
+; SLEEF: // %bb.0:
+; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT: .cfi_def_cfa_offset 16
+; SLEEF-NEXT: .cfi_offset w30, -16
+; SLEEF-NEXT: mov v0.16b, v1.16b
+; SLEEF-NEXT: mov v1.16b, v2.16b
+; SLEEF-NEXT: bl _ZGVnN4vv_fmodf
+; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT: ret
+ %res = frem <4 x float> %a, %b
+ ret <4 x float> %res
+}
+
+define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; ARMPL-LABEL: frem_nxv4f32:
+; ARMPL: // %bb.0:
+; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT: .cfi_def_cfa_offset 16
+; ARMPL-NEXT: .cfi_offset w30, -16
+; ARMPL-NEXT: ptrue p0.s
+; ARMPL-NEXT: mov z0.d, z1.d
+; ARMPL-NEXT: mov z1.d, z2.d
+; ARMPL-NEXT: bl armpl_svfmod_f32_x
+; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT: ret
+;
+; SLEEF-LABEL: frem_nxv4f32:
+; SLEEF: // %bb.0:
+; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT: .cfi_def_cfa_offset 16
+; SLEEF-NEXT: .cfi_offset w30, -16
+; SLEEF-NEXT: ptrue p0.s
+; SLEEF-NEXT: mov z0.d, z1.d
+; SLEEF-NEXT: mov z1.d, z2.d
+; SLEEF-NEXT: bl _ZGVsMxvv_fmodf
+; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT: ret
+ %res = frem <vscale x 4 x float> %a, %b
+ ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 {
+; ARMPL-LABEL: frem_strict_nxv2f64:
+; ARMPL: // %bb.0:
+; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT: .cfi_def_cfa_offset 16
+; ARMPL-NEXT: .cfi_offset w30, -16
+; ARMPL-NEXT: ptrue p0.d
+; ARMPL-NEXT: mov z0.d, z1.d
+; ARMPL-NEXT: mov z1.d, z2.d
+; ARMPL-NEXT: bl armpl_svfmod_f64_x
+; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT: ret
+;
+; SLEEF-LABEL: frem_strict_nxv2f64:
+; SLEEF: // %bb.0:
+; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT: .cfi_def_cfa_offset 16
+; SLEEF-NEXT: .cfi_offset w30, -16
+; SLEEF-NEXT: ptrue p0.d
+; SLEEF-NEXT: mov z0.d, z1.d
+; SLEEF-NEXT: mov z1.d, z2.d
+; SLEEF-NEXT: bl _ZGVsMxvv_fmod
+; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT: ret
+ %res = frem <vscale x 2 x double> %a, %b
+ ret <vscale x 2 x double> %res
+}
+
+attributes #0 = { "target-features"="+sve" }
+attributes #1 = { "target-features"="+sve" strictfp }
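To summarise the end-to-end behaviour, here is a minimal LLVM IR sketch distilled from the tests above (the function name frem_example is hypothetical). Before this patch, a scalable-vector frem like this could not be legalized without the IR-level ReplaceWithVeclib pass, because the generic fallback scalarizes the operation, which is impossible for scalable types:

    ; Compile with: llc --disable-replace-with-vec-lib --vector-library=ArmPL
    ; The legalizer now expands the frem below into a call to armpl_svfmod_f64_x.
    target triple = "aarch64-unknown-linux-gnu"

    define <vscale x 2 x double> @frem_example(<vscale x 2 x double> %a, <vscale x 2 x double> %b) #0 {
      %res = frem <vscale x 2 x double> %a, %b
      ret <vscale x 2 x double> %res
    }

    attributes #0 = { "target-features"="+sve" }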
Review comment on:

    const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
    const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
    if (!VD)
      VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
Suggested change:

    -      VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked*/ true);
    +      VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/ true);
Done.
✅ With the latest revision this PR passed the C/C++ code formatter.
Thanks for the changes.
|
||
Review comment on:

    /// Try to expand the node to a vector libcall based on the result type.
    bool VectorLegalizer::tryExpandVecMathCall(
        SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
Is there a reason the Libcall variants are parameters to this function, instead of being hardcoded? (Looking at other functions like DAGTypeLegalizer::SoftenFloatRes_FREM, I see the list being hardcoded.)
I'm expecting this function to be used by other ISD nodes (e.g. ISD::FSIN, etc.), so I followed the idiom used by ExpandFPLibCall.
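For instance, a hypothetical follow-up (not part of this patch) could wire another libm node into the same Expand switch, mirroring the ISD::FREM case in the diff above and assuming the matching RTLIB entries:

    // Sketch only: how ISD::FSIN could reuse tryExpandVecMathCall, following
    // the idiom of the ISD::FREM case added by this patch.
    case ISD::FSIN:
      if (tryExpandVecMathCall(Node, RTLIB::SIN_F32, RTLIB::SIN_F64,
                               RTLIB::SIN_F80, RTLIB::SIN_F128,
                               RTLIB::SIN_PPCF128, Results))
        return;

      break;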
LGTM
Refactor the pass to only support `IntrinsicInst` calls. `ReplaceWithVecLib` used to support instructions, as AArch64 was using this pass to replace a vectorized frem instruction with the fmod vector library call (through TLI). As this replacement is now done by the codegen (#83859), there is no need for this pass to support instructions. Additionally, removed 'frem' tests from:
- AArch64/replace-with-veclib-armpl.ll
- AArch64/replace-with-veclib-sleef-scalable.ll
- AArch64/replace-with-veclib-sleef.ll
Such testing is done at codegen level (#83859).