|
28 | 28 |
|
29 | 29 | #include "llvm/ADT/DenseMap.h"
|
30 | 30 | #include "llvm/ADT/SmallVector.h"
|
| 31 | +#include "llvm/Analysis/TargetLibraryInfo.h" |
| 32 | +#include "llvm/Analysis/VectorUtils.h" |
31 | 33 | #include "llvm/CodeGen/ISDOpcodes.h"
|
32 | 34 | #include "llvm/CodeGen/SelectionDAG.h"
|
33 | 35 | #include "llvm/CodeGen/SelectionDAGNodes.h"
|
@@ -147,6 +149,14 @@ class VectorLegalizer {
|
147 | 149 | void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
|
148 | 150 | void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
|
149 | 151 |
|
| 152 | + bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC, |
| 153 | + SmallVectorImpl<SDValue> &Results); |
| 154 | + bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32, |
| 155 | + RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80, |
| 156 | + RTLIB::Libcall Call_F128, |
| 157 | + RTLIB::Libcall Call_PPCF128, |
| 158 | + SmallVectorImpl<SDValue> &Results); |
| 159 | + |
150 | 160 | void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
|
151 | 161 |
|
152 | 162 | /// Implements vector promotion.
|
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
|
1139 | 1149 | case ISD::VP_MERGE:
|
1140 | 1150 | Results.push_back(ExpandVP_MERGE(Node));
|
1141 | 1151 | return;
|
| 1152 | + case ISD::FREM: |
| 1153 | + if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64, |
| 1154 | + RTLIB::REM_F80, RTLIB::REM_F128, |
| 1155 | + RTLIB::REM_PPCF128, Results)) |
| 1156 | + return; |
| 1157 | + |
| 1158 | + break; |
1142 | 1159 | }
|
1143 | 1160 |
|
1144 | 1161 | SDValue Unrolled = DAG.UnrollVectorOp(Node);
|
@@ -1842,6 +1859,117 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
|
1842 | 1859 | Results.push_back(Result);
|
1843 | 1860 | }
|
1844 | 1861 |
|
| 1862 | +// Try to expand libm nodes into vector math routine calls. Callers provide the |
| 1863 | +// LibFunc equivalent of the passed in Node, which is used to lookup mappings |
| 1864 | +// within TargetLibraryInfo. The only mappings considered are those where the |
| 1865 | +// result and all operands are the same vector type. While predicated nodes are |
| 1866 | +// not supported, we will emit calls to masked routines by passing in an all |
| 1867 | +// true mask. |
| 1868 | +bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC, |
| 1869 | + SmallVectorImpl<SDValue> &Results) { |
| 1870 | + // Chain must be propagated but currently strict fp operations are down |
| 1871 | + // converted to their none strict counterpart. |
| 1872 | + assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!"); |
| 1873 | + |
| 1874 | + const char *LCName = TLI.getLibcallName(LC); |
| 1875 | + if (!LCName) |
| 1876 | + return false; |
| 1877 | + LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n"); |
| 1878 | + |
| 1879 | + EVT VT = Node->getValueType(0); |
| 1880 | + ElementCount VL = VT.getVectorElementCount(); |
| 1881 | + |
| 1882 | + // Lookup a vector function equivalent to the specified libcall. Prefer |
| 1883 | + // unmasked variants but we will generate a mask if need be. |
| 1884 | + const TargetLibraryInfo &TLibInfo = DAG.getLibInfo(); |
| 1885 | + const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false); |
| 1886 | + if (!VD) |
| 1887 | + VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/true); |
| 1888 | + if (!VD) |
| 1889 | + return false; |
| 1890 | + |
| 1891 | + LLVMContext *Ctx = DAG.getContext(); |
| 1892 | + Type *Ty = VT.getTypeForEVT(*Ctx); |
| 1893 | + Type *ScalarTy = Ty->getScalarType(); |
| 1894 | + |
| 1895 | + // Construct a scalar function type based on Node's operands. |
| 1896 | + SmallVector<Type *, 8> ArgTys; |
| 1897 | + for (unsigned i = 0; i < Node->getNumOperands(); ++i) { |
| 1898 | + assert(Node->getOperand(i).getValueType() == VT && |
| 1899 | + "Expected matching vector types!"); |
| 1900 | + ArgTys.push_back(ScalarTy); |
| 1901 | + } |
| 1902 | + FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false); |
| 1903 | + |
| 1904 | + // Generate call information for the vector function. |
| 1905 | + const std::string MangledName = VD->getVectorFunctionABIVariantString(); |
| 1906 | + auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy); |
| 1907 | + if (!OptVFInfo) |
| 1908 | + return false; |
| 1909 | + |
| 1910 | + LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName() |
| 1911 | + << "\n"); |
| 1912 | + |
| 1913 | + // Sanity check just in case OptVFInfo has unexpected parameters. |
| 1914 | + if (OptVFInfo->Shape.Parameters.size() != |
| 1915 | + Node->getNumOperands() + VD->isMasked()) |
| 1916 | + return false; |
| 1917 | + |
| 1918 | + // Collect vector call operands. |
| 1919 | + |
| 1920 | + SDLoc DL(Node); |
| 1921 | + TargetLowering::ArgListTy Args; |
| 1922 | + TargetLowering::ArgListEntry Entry; |
| 1923 | + Entry.IsSExt = false; |
| 1924 | + Entry.IsZExt = false; |
| 1925 | + |
| 1926 | + unsigned OpNum = 0; |
| 1927 | + for (auto &VFParam : OptVFInfo->Shape.Parameters) { |
| 1928 | + if (VFParam.ParamKind == VFParamKind::GlobalPredicate) { |
| 1929 | + EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT); |
| 1930 | + Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT); |
| 1931 | + Entry.Ty = MaskVT.getTypeForEVT(*Ctx); |
| 1932 | + Args.push_back(Entry); |
| 1933 | + continue; |
| 1934 | + } |
| 1935 | + |
| 1936 | + // Only vector operands are supported. |
| 1937 | + if (VFParam.ParamKind != VFParamKind::Vector) |
| 1938 | + return false; |
| 1939 | + |
| 1940 | + Entry.Node = Node->getOperand(OpNum++); |
| 1941 | + Entry.Ty = Ty; |
| 1942 | + Args.push_back(Entry); |
| 1943 | + } |
| 1944 | + |
| 1945 | + // Emit a call to the vector function. |
| 1946 | + SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(), |
| 1947 | + TLI.getPointerTy(DAG.getDataLayout())); |
| 1948 | + TargetLowering::CallLoweringInfo CLI(DAG); |
| 1949 | + CLI.setDebugLoc(DL) |
| 1950 | + .setChain(DAG.getEntryNode()) |
| 1951 | + .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args)); |
| 1952 | + |
| 1953 | + std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI); |
| 1954 | + Results.push_back(CallResult.first); |
| 1955 | + return true; |
| 1956 | +} |
| 1957 | + |
| 1958 | +/// Try to expand the node to a vector libcall based on the result type. |
| 1959 | +bool VectorLegalizer::tryExpandVecMathCall( |
| 1960 | + SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64, |
| 1961 | + RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128, |
| 1962 | + RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) { |
| 1963 | + RTLIB::Libcall LC = RTLIB::getFPLibCall( |
| 1964 | + Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64, |
| 1965 | + Call_F80, Call_F128, Call_PPCF128); |
| 1966 | + |
| 1967 | + if (LC == RTLIB::UNKNOWN_LIBCALL) |
| 1968 | + return false; |
| 1969 | + |
| 1970 | + return tryExpandVecMathCall(Node, LC, Results); |
| 1971 | +} |
| 1972 | + |
1845 | 1973 | void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
|
1846 | 1974 | SmallVectorImpl<SDValue> &Results) {
|
1847 | 1975 | EVT VT = Node->getValueType(0);
|
|
0 commit comments