Skip to content

Commit 89a8c71

Browse files
authored
[SDAG] Support expanding FSINCOS to vector library calls (#114039)
This shares most of its code with the scalar sincos expansion. It allows expanding vector FSINCOS nodes to a library call from the specified `-vector-library`. The upside of this is it will mean the vectorizer only needs to handle the sincos intrinsic, which has no memory effects, and this can handle lowering the intrinsic to a call that takes output pointers.
1 parent 3f17613 commit 89a8c71

File tree

5 files changed

+168
-70
lines changed

5 files changed

+168
-70
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

+3
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,9 @@ class SelectionDAG {
15951595
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
15961596
SDValue Op2);
15971597

1598+
/// Expand the specified \c ISD::FSINCOS node as the Legalize pass would.
1599+
bool expandFSINCOS(SDNode *Node, SmallVectorImpl<SDValue> &Results);
1600+
15981601
/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
15991602
SDValue expandVAArg(SDNode *Node);
16001603

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

+1-70
Original file line numberDiff line numberDiff line change
@@ -2345,75 +2345,6 @@ static bool useSinCos(SDNode *Node) {
23452345
return false;
23462346
}
23472347

2348-
/// Issue libcalls to sincos to compute sin / cos pairs.
2349-
void SelectionDAGLegalize::ExpandSinCosLibCall(
2350-
SDNode *Node, SmallVectorImpl<SDValue> &Results) {
2351-
EVT VT = Node->getValueType(0);
2352-
Type *Ty = VT.getTypeForEVT(*DAG.getContext());
2353-
RTLIB::Libcall LC = RTLIB::getFSINCOS(VT);
2354-
2355-
// Find users of the node that store the results (and share input chains). The
2356-
// destination pointers can be used instead of creating stack allocations.
2357-
SDValue StoresInChain{};
2358-
std::array<StoreSDNode *, 2> ResultStores = {nullptr};
2359-
for (SDNode *User : Node->uses()) {
2360-
if (!ISD::isNormalStore(User))
2361-
continue;
2362-
auto *ST = cast<StoreSDNode>(User);
2363-
if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
2364-
ST->getAlign() < DAG.getDataLayout().getABITypeAlign(Ty) ||
2365-
(StoresInChain && ST->getChain() != StoresInChain) ||
2366-
Node->isPredecessorOf(ST->getChain().getNode()))
2367-
continue;
2368-
ResultStores[ST->getValue().getResNo()] = ST;
2369-
StoresInChain = ST->getChain();
2370-
}
2371-
2372-
TargetLowering::ArgListTy Args;
2373-
TargetLowering::ArgListEntry Entry{};
2374-
2375-
// Pass the argument.
2376-
Entry.Node = Node->getOperand(0);
2377-
Entry.Ty = Ty;
2378-
Args.push_back(Entry);
2379-
2380-
// Pass the output pointers for sin and cos.
2381-
SmallVector<SDValue, 2> ResultPtrs{};
2382-
for (StoreSDNode *ST : ResultStores) {
2383-
SDValue ResultPtr = ST ? ST->getBasePtr() : DAG.CreateStackTemporary(VT);
2384-
Entry.Node = ResultPtr;
2385-
Entry.Ty = PointerType::getUnqual(Ty->getContext());
2386-
Args.push_back(Entry);
2387-
ResultPtrs.push_back(ResultPtr);
2388-
}
2389-
2390-
SDLoc DL(Node);
2391-
SDValue InChain = StoresInChain ? StoresInChain : DAG.getEntryNode();
2392-
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
2393-
TLI.getPointerTy(DAG.getDataLayout()));
2394-
TargetLowering::CallLoweringInfo CLI(DAG);
2395-
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
2396-
TLI.getLibcallCallingConv(LC), Type::getVoidTy(*DAG.getContext()), Callee,
2397-
std::move(Args));
2398-
2399-
auto [Call, OutChain] = TLI.LowerCallTo(CLI);
2400-
2401-
for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
2402-
MachinePointerInfo PtrInfo;
2403-
if (StoreSDNode *ST = ResultStores[ResNo]) {
2404-
// Replace store with the library call.
2405-
DAG.ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
2406-
PtrInfo = ST->getPointerInfo();
2407-
} else {
2408-
PtrInfo = MachinePointerInfo::getFixedStack(
2409-
DAG.getMachineFunction(),
2410-
cast<FrameIndexSDNode>(ResultPtr)->getIndex());
2411-
}
2412-
SDValue LoadResult = DAG.getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
2413-
Results.push_back(LoadResult);
2414-
}
2415-
}
2416-
24172348
SDValue SelectionDAGLegalize::expandLdexp(SDNode *Node) const {
24182349
SDLoc dl(Node);
24192350
EVT VT = Node->getValueType(0);
@@ -4633,7 +4564,7 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
46334564
break;
46344565
case ISD::FSINCOS:
46354566
// Expand into sincos libcall.
4636-
ExpandSinCosLibCall(Node, Results);
4567+
(void)DAG.expandFSINCOS(Node, Results);
46374568
break;
46384569
case ISD::FLOG:
46394570
case ISD::STRICT_FLOG:

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,11 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
11911191
RTLIB::REM_PPCF128, Results))
11921192
return;
11931193

1194+
break;
1195+
case ISD::FSINCOS:
1196+
if (DAG.expandFSINCOS(Node, Results))
1197+
return;
1198+
11941199
break;
11951200
case ISD::VECTOR_COMPRESS:
11961201
Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG));

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

+98
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/ADT/Twine.h"
2626
#include "llvm/Analysis/AliasAnalysis.h"
2727
#include "llvm/Analysis/MemoryLocation.h"
28+
#include "llvm/Analysis/TargetLibraryInfo.h"
2829
#include "llvm/Analysis/ValueTracking.h"
2930
#include "llvm/Analysis/VectorUtils.h"
3031
#include "llvm/BinaryFormat/Dwarf.h"
@@ -2483,6 +2484,103 @@ SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
24832484
return Subvectors[0];
24842485
}
24852486

2487+
bool SelectionDAG::expandFSINCOS(SDNode *Node,
2488+
SmallVectorImpl<SDValue> &Results) {
2489+
EVT VT = Node->getValueType(0);
2490+
LLVMContext *Ctx = getContext();
2491+
Type *Ty = VT.getTypeForEVT(*Ctx);
2492+
RTLIB::Libcall LC =
2493+
RTLIB::getFSINCOS(VT.isVector() ? VT.getVectorElementType() : VT);
2494+
2495+
const char *LCName = TLI->getLibcallName(LC);
2496+
if (!LC || !LCName)
2497+
return false;
2498+
2499+
auto getVecDesc = [&]() -> VecDesc const * {
2500+
for (bool Masked : {false, true}) {
2501+
if (VecDesc const *VD = getLibInfo().getVectorMappingInfo(
2502+
LCName, VT.getVectorElementCount(), Masked)) {
2503+
return VD;
2504+
}
2505+
}
2506+
return nullptr;
2507+
};
2508+
2509+
VecDesc const *VD = nullptr;
2510+
if (VT.isVector() && !(VD = getVecDesc()))
2511+
return false;
2512+
2513+
// Find users of the node that store the results (and share input chains). The
2514+
// destination pointers can be used instead of creating stack allocations.
2515+
SDValue StoresInChain{};
2516+
std::array<StoreSDNode *, 2> ResultStores = {nullptr};
2517+
for (SDNode *User : Node->uses()) {
2518+
if (!ISD::isNormalStore(User))
2519+
continue;
2520+
auto *ST = cast<StoreSDNode>(User);
2521+
if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
2522+
ST->getAlign() < getDataLayout().getABITypeAlign(Ty->getScalarType()) ||
2523+
(StoresInChain && ST->getChain() != StoresInChain) ||
2524+
Node->isPredecessorOf(ST->getChain().getNode()))
2525+
continue;
2526+
ResultStores[ST->getValue().getResNo()] = ST;
2527+
StoresInChain = ST->getChain();
2528+
}
2529+
2530+
TargetLowering::ArgListTy Args;
2531+
TargetLowering::ArgListEntry Entry{};
2532+
2533+
// Pass the argument.
2534+
Entry.Node = Node->getOperand(0);
2535+
Entry.Ty = Ty;
2536+
Args.push_back(Entry);
2537+
2538+
// Pass the output pointers for sin and cos.
2539+
SmallVector<SDValue, 2> ResultPtrs{};
2540+
for (StoreSDNode *ST : ResultStores) {
2541+
SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(VT);
2542+
Entry.Node = ResultPtr;
2543+
Entry.Ty = PointerType::getUnqual(Ty->getContext());
2544+
Args.push_back(Entry);
2545+
ResultPtrs.push_back(ResultPtr);
2546+
}
2547+
2548+
SDLoc DL(Node);
2549+
2550+
if (VD && VD->isMasked()) {
2551+
EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), *Ctx, VT);
2552+
Entry.Node = getBoolConstant(true, DL, MaskVT, VT);
2553+
Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
2554+
Args.push_back(Entry);
2555+
}
2556+
2557+
SDValue InChain = StoresInChain ? StoresInChain : getEntryNode();
2558+
SDValue Callee = getExternalSymbol(VD ? VD->getVectorFnName().data() : LCName,
2559+
TLI->getPointerTy(getDataLayout()));
2560+
TargetLowering::CallLoweringInfo CLI(*this);
2561+
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
2562+
TLI->getLibcallCallingConv(LC), Type::getVoidTy(*Ctx), Callee,
2563+
std::move(Args));
2564+
2565+
auto [Call, OutChain] = TLI->LowerCallTo(CLI);
2566+
2567+
for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
2568+
MachinePointerInfo PtrInfo;
2569+
if (StoreSDNode *ST = ResultStores[ResNo]) {
2570+
// Replace store with the library call.
2571+
ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
2572+
PtrInfo = ST->getPointerInfo();
2573+
} else {
2574+
PtrInfo = MachinePointerInfo::getFixedStack(
2575+
getMachineFunction(), cast<FrameIndexSDNode>(ResultPtr)->getIndex());
2576+
}
2577+
SDValue LoadResult = getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
2578+
Results.push_back(LoadResult);
2579+
}
2580+
2581+
return true;
2582+
}
2583+
24862584
SDValue SelectionDAG::expandVAArg(SDNode *Node) {
24872585
SDLoc dl(Node);
24882586
const TargetLowering &TLI = getTargetLoweringInfo();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --filter "(bl|ptrue)" --version 5
2+
; RUN: llc -mtriple=aarch64-gnu-linux -mattr=+neon,+sve -vector-library=sleefgnuabi < %s | FileCheck %s -check-prefix=SLEEF
3+
; RUN: llc -mtriple=aarch64-gnu-linux -mattr=+neon,+sve -vector-library=ArmPL < %s | FileCheck %s -check-prefix=ARMPL
4+
5+
define void @test_sincos_v4f32(<4 x float> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
6+
; SLEEF-LABEL: test_sincos_v4f32:
7+
; SLEEF: bl _ZGVnN4vl4l4_sincosf
8+
;
9+
; ARMPL-LABEL: test_sincos_v4f32:
10+
; ARMPL: bl armpl_vsincosq_f32
11+
%result = call { <4 x float>, <4 x float> } @llvm.sincos.v4f32(<4 x float> %x)
12+
%result.0 = extractvalue { <4 x float>, <4 x float> } %result, 0
13+
%result.1 = extractvalue { <4 x float>, <4 x float> } %result, 1
14+
store <4 x float> %result.0, ptr %out_sin, align 4
15+
store <4 x float> %result.1, ptr %out_cos, align 4
16+
ret void
17+
}
18+
19+
define void @test_sincos_v2f64(<2 x double> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
20+
; SLEEF-LABEL: test_sincos_v2f64:
21+
; SLEEF: bl _ZGVnN2vl8l8_sincos
22+
;
23+
; ARMPL-LABEL: test_sincos_v2f64:
24+
; ARMPL: bl armpl_vsincosq_f64
25+
%result = call { <2 x double>, <2 x double> } @llvm.sincos.v2f64(<2 x double> %x)
26+
%result.0 = extractvalue { <2 x double>, <2 x double> } %result, 0
27+
%result.1 = extractvalue { <2 x double>, <2 x double> } %result, 1
28+
store <2 x double> %result.0, ptr %out_sin, align 8
29+
store <2 x double> %result.1, ptr %out_cos, align 8
30+
ret void
31+
}
32+
33+
define void @test_sincos_nxv4f32(<vscale x 4 x float> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
34+
; SLEEF-LABEL: test_sincos_nxv4f32:
35+
; SLEEF: bl _ZGVsNxvl4l4_sincosf
36+
;
37+
; ARMPL-LABEL: test_sincos_nxv4f32:
38+
; ARMPL: ptrue p0.s
39+
; ARMPL: bl armpl_svsincos_f32_x
40+
%result = call { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.sincos.nxv4f32(<vscale x 4 x float> %x)
41+
%result.0 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %result, 0
42+
%result.1 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %result, 1
43+
store <vscale x 4 x float> %result.0, ptr %out_sin, align 4
44+
store <vscale x 4 x float> %result.1, ptr %out_cos, align 4
45+
ret void
46+
}
47+
48+
define void @test_sincos_nxv2f64(<vscale x 2 x double> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
49+
; SLEEF-LABEL: test_sincos_nxv2f64:
50+
; SLEEF: bl _ZGVsNxvl8l8_sincos
51+
;
52+
; ARMPL-LABEL: test_sincos_nxv2f64:
53+
; ARMPL: ptrue p0.d
54+
; ARMPL: bl armpl_svsincos_f64_x
55+
%result = call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.sincos.nxv2f64(<vscale x 2 x double> %x)
56+
%result.0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %result, 0
57+
%result.1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %result, 1
58+
store <vscale x 2 x double> %result.0, ptr %out_sin, align 8
59+
store <vscale x 2 x double> %result.1, ptr %out_cos, align 8
60+
ret void
61+
}

0 commit comments

Comments
 (0)