Commit bd6eb54

[LLVM][CodeGen] Teach SelectionDAG how to expand FREM to a vector math call. (#83859)

This removes, at least when a vector library is available, a failure case for scalable vectors. Doing so means we can confidently cost vector FREM instructions without assuming that later passes will transform the IR before it reaches the code generator.

NOTE: Whilst only FREM has been implemented, the same mechanism can be used for the other libm-related ISD nodes.
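Concretely (taken from the tests added below), a scalable-vector frem such as

    %res = frem <vscale x 2 x double> %a, %b

previously had no lowering; it now becomes a call to the mapped vector routine, e.g. bl armpl_svfmod_f64_x with ArmPL or bl _ZGVsMxvv_fmod with SLEEF.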
1 parent 4d478bc commit bd6eb54

File tree: 3 files changed, +249 -1 lines changed


llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 128 additions & 0 deletions
@@ -28,6 +28,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -147,6 +149,14 @@ class VectorLegalizer {
   void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
   void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                            SmallVectorImpl<SDValue> &Results);
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32,
+                            RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80,
+                            RTLIB::Libcall Call_F128,
+                            RTLIB::Libcall Call_PPCF128,
+                            SmallVectorImpl<SDValue> &Results);
+
   void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   /// Implements vector promotion.
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VP_MERGE:
     Results.push_back(ExpandVP_MERGE(Node));
     return;
+  case ISD::FREM:
+    if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64,
+                             RTLIB::REM_F80, RTLIB::REM_F128,
+                             RTLIB::REM_PPCF128, Results))
+      return;
+
+    break;
   }
 
   SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1842,6 +1859,117 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
   Results.push_back(Result);
 }
 
+// Try to expand libm nodes into vector math routine calls. Callers provide
+// the LibFunc equivalent of the passed-in Node, which is used to look up
+// mappings within TargetLibraryInfo. The only mappings considered are those
+// where the result and all operands are the same vector type. While
+// predicated nodes are not supported, we will emit calls to masked routines
+// by passing in an all-true mask.
+bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                                           SmallVectorImpl<SDValue> &Results) {
+  // A chain would need to be propagated, but strict FP operations are
+  // currently converted to their non-strict counterparts before this point.
+  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");
+
+  const char *LCName = TLI.getLibcallName(LC);
+  if (!LCName)
+    return false;
+  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");
+
+  EVT VT = Node->getValueType(0);
+  ElementCount VL = VT.getVectorElementCount();
+
+  // Look up a vector function equivalent to the specified libcall. Prefer
+  // unmasked variants, but we will generate a mask if need be.
+  const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
+  const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
+  if (!VD)
+    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/true);
+  if (!VD)
+    return false;
+
+  LLVMContext *Ctx = DAG.getContext();
+  Type *Ty = VT.getTypeForEVT(*Ctx);
+  Type *ScalarTy = Ty->getScalarType();
+
+  // Construct a scalar function type based on Node's operands.
+  SmallVector<Type *, 8> ArgTys;
+  for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
+    assert(Node->getOperand(i).getValueType() == VT &&
+           "Expected matching vector types!");
+    ArgTys.push_back(ScalarTy);
+  }
+  FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);
+
+  // Generate call information for the vector function.
+  const std::string MangledName = VD->getVectorFunctionABIVariantString();
+  auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+  if (!OptVFInfo)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
+                    << "\n");
+
+  // Sanity check just in case OptVFInfo has unexpected parameters.
+  if (OptVFInfo->Shape.Parameters.size() !=
+      Node->getNumOperands() + VD->isMasked())
+    return false;
+
+  // Collect vector call operands.
+
+  SDLoc DL(Node);
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.IsSExt = false;
+  Entry.IsZExt = false;
+
+  unsigned OpNum = 0;
+  for (auto &VFParam : OptVFInfo->Shape.Parameters) {
+    if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
+      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
+      Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
+      Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
+      Args.push_back(Entry);
+      continue;
+    }
+
+    // Only vector operands are supported.
+    if (VFParam.ParamKind != VFParamKind::Vector)
+      return false;
+
+    Entry.Node = Node->getOperand(OpNum++);
+    Entry.Ty = Ty;
+    Args.push_back(Entry);
+  }
+
+  // Emit a call to the vector function.
+  SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
+                                         TLI.getPointerTy(DAG.getDataLayout()));
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(DAG.getEntryNode())
+      .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));
+
+  std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+  Results.push_back(CallResult.first);
+  return true;
+}
+
+/// Try to expand the node to a vector libcall based on the result type.
+bool VectorLegalizer::tryExpandVecMathCall(
+    SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
+    RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
+    RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) {
+  RTLIB::Libcall LC = RTLIB::getFPLibCall(
+      Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64,
+      Call_F80, Call_F128, Call_PPCF128);
+
+  if (LC == RTLIB::UNKNOWN_LIBCALL)
+    return false;
+
+  return tryExpandVecMathCall(Node, LC, Results);
+}
+
 void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   EVT VT = Node->getValueType(0);
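The NOTE in the commit message says this mechanism is not FREM-specific. As an illustration only (not part of this commit), routing another libm-related node such as ISD::FSIN through the same expansion would amount to one more case in the switch above, reusing the existing RTLIB SIN_* libcall entries:

  // Hypothetical extension, not in this commit: expand FSIN via the same
  // vector math call mechanism, keyed off the existing scalar sine libcalls.
  case ISD::FSIN:
    if (tryExpandVecMathCall(Node, RTLIB::SIN_F32, RTLIB::SIN_F64,
                             RTLIB::SIN_F80, RTLIB::SIN_F128,
                             RTLIB::SIN_PPCF128, Results))
      return;
    break;

This would pick up vector sine variants automatically wherever TargetLibraryInfo has a mapping for them.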

llvm/lib/CodeGen/TargetPassConfig.cpp

Lines changed: 5 additions & 1 deletion
@@ -205,6 +205,10 @@
 static cl::opt<bool> EarlyLiveIntervals("early-live-intervals", cl::Hidden,
     cl::desc("Run live interval analysis earlier in the pipeline"));
 
+static cl::opt<bool> DisableReplaceWithVecLib(
+    "disable-replace-with-vec-lib", cl::Hidden,
+    cl::desc("Disable replace with vector math call pass"));
+
 /// Option names for limiting the codegen pipeline.
 /// Those are used in error reporting and we didn't want
 /// to duplicate their names all over the place.
@@ -856,7 +860,7 @@ void TargetPassConfig::addIRPasses() {
   if (getOptLevel() != CodeGenOptLevel::None && !DisableConstantHoisting)
     addPass(createConstantHoistingPass());
 
-  if (getOptLevel() != CodeGenOptLevel::None)
+  if (getOptLevel() != CodeGenOptLevel::None && !DisableReplaceWithVecLib)
     addPass(createReplaceWithVeclibLegacyPass());
 
   if (getOptLevel() != CodeGenOptLevel::None && !DisablePartialLibcallInlining)
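This flag exists so the new code path can be tested: without it, the ReplaceWithVeclib IR pass could lower the vector frem instructions to veclib calls before instruction selection, and the SelectionDAG expansion above would never be exercised. Both RUN lines in the test below therefore pass --disable-replace-with-vec-lib to llc.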
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s
; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s

target triple = "aarch64-unknown-linux-gnu"

define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 {
; ARMPL-LABEL: frem_v2f64:
; ARMPL:       // %bb.0:
; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; ARMPL-NEXT:    .cfi_def_cfa_offset 16
; ARMPL-NEXT:    .cfi_offset w30, -16
; ARMPL-NEXT:    mov v0.16b, v1.16b
; ARMPL-NEXT:    mov v1.16b, v2.16b
; ARMPL-NEXT:    bl armpl_vfmodq_f64
; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; ARMPL-NEXT:    ret
;
; SLEEF-LABEL: frem_v2f64:
; SLEEF:       // %bb.0:
; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; SLEEF-NEXT:    .cfi_def_cfa_offset 16
; SLEEF-NEXT:    .cfi_offset w30, -16
; SLEEF-NEXT:    mov v0.16b, v1.16b
; SLEEF-NEXT:    mov v1.16b, v2.16b
; SLEEF-NEXT:    bl _ZGVnN2vv_fmod
; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; SLEEF-NEXT:    ret
  %res = frem <2 x double> %a, %b
  ret <2 x double> %res
}

define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 {
; ARMPL-LABEL: frem_strict_v4f32:
; ARMPL:       // %bb.0:
; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; ARMPL-NEXT:    .cfi_def_cfa_offset 16
; ARMPL-NEXT:    .cfi_offset w30, -16
; ARMPL-NEXT:    mov v0.16b, v1.16b
; ARMPL-NEXT:    mov v1.16b, v2.16b
; ARMPL-NEXT:    bl armpl_vfmodq_f32
; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; ARMPL-NEXT:    ret
;
; SLEEF-LABEL: frem_strict_v4f32:
; SLEEF:       // %bb.0:
; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; SLEEF-NEXT:    .cfi_def_cfa_offset 16
; SLEEF-NEXT:    .cfi_offset w30, -16
; SLEEF-NEXT:    mov v0.16b, v1.16b
; SLEEF-NEXT:    mov v1.16b, v2.16b
; SLEEF-NEXT:    bl _ZGVnN4vv_fmodf
; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; SLEEF-NEXT:    ret
  %res = frem <4 x float> %a, %b
  ret <4 x float> %res
}

define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
; ARMPL-LABEL: frem_nxv4f32:
; ARMPL:       // %bb.0:
; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; ARMPL-NEXT:    .cfi_def_cfa_offset 16
; ARMPL-NEXT:    .cfi_offset w30, -16
; ARMPL-NEXT:    ptrue p0.s
; ARMPL-NEXT:    mov z0.d, z1.d
; ARMPL-NEXT:    mov z1.d, z2.d
; ARMPL-NEXT:    bl armpl_svfmod_f32_x
; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; ARMPL-NEXT:    ret
;
; SLEEF-LABEL: frem_nxv4f32:
; SLEEF:       // %bb.0:
; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; SLEEF-NEXT:    .cfi_def_cfa_offset 16
; SLEEF-NEXT:    .cfi_offset w30, -16
; SLEEF-NEXT:    ptrue p0.s
; SLEEF-NEXT:    mov z0.d, z1.d
; SLEEF-NEXT:    mov z1.d, z2.d
; SLEEF-NEXT:    bl _ZGVsMxvv_fmodf
; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; SLEEF-NEXT:    ret
  %res = frem <vscale x 4 x float> %a, %b
  ret <vscale x 4 x float> %res
}

define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 {
; ARMPL-LABEL: frem_strict_nxv2f64:
; ARMPL:       // %bb.0:
; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; ARMPL-NEXT:    .cfi_def_cfa_offset 16
; ARMPL-NEXT:    .cfi_offset w30, -16
; ARMPL-NEXT:    ptrue p0.d
; ARMPL-NEXT:    mov z0.d, z1.d
; ARMPL-NEXT:    mov z1.d, z2.d
; ARMPL-NEXT:    bl armpl_svfmod_f64_x
; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; ARMPL-NEXT:    ret
;
; SLEEF-LABEL: frem_strict_nxv2f64:
; SLEEF:       // %bb.0:
; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; SLEEF-NEXT:    .cfi_def_cfa_offset 16
; SLEEF-NEXT:    .cfi_offset w30, -16
; SLEEF-NEXT:    ptrue p0.d
; SLEEF-NEXT:    mov z0.d, z1.d
; SLEEF-NEXT:    mov z1.d, z2.d
; SLEEF-NEXT:    bl _ZGVsMxvv_fmod
; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; SLEEF-NEXT:    ret
  %res = frem <vscale x 2 x double> %a, %b
  ret <vscale x 2 x double> %res
}

attributes #0 = { "target-features"="+sve" }
attributes #1 = { "target-features"="+sve" strictfp }
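Note that the *_strict functions carry the strictfp attribute (attributes #1) yet receive the same lowering as the non-strict ones: by the time these nodes reach the vector legalizer, strict FP operations have already been converted to their non-strict counterparts, which is what the assert at the top of tryExpandVecMathCall documents.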
