Skip to content

Commit 98c90a1

Browse files
authored
ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering (#66924)
Issue #55208 observed that std::rint is vectorized by the SLPVectorizer, but a very similar function, std::lrint, is not. std::lrint corresponds to ISD::LRINT in the SelectionDAG, and std::llrint is a close cousin corresponding to ISD::LLRINT. Currently, neither ISD::LRINT nor ISD::LLRINT has a corresponding vector variant, and the LangRef makes this clear in the documentation of llvm.lrint.* and llvm.llrint.*. This patch extends the LangRef to include vector variants of llvm.lrint.* and llvm.llrint.*, and lays the necessary groundwork for scalarizing them on all targets. However, this patch would be devoid of motivation unless we showed the utility of these new vector variants. Hence, the RISCV target has been chosen to implement a custom lowering to the vfcvt.x.f.v instruction. The patch also includes a CostModel for RISCV, and a trivial follow-up can potentially enable the SLPVectorizer to vectorize std::lrint and std::llrint, fixing #55208. The patch includes tests, primarily for the RISCV target, but also for the X86, AArch64, and PowerPC targets to justify the addition of the vector variants to the LangRef.
1 parent 3d7802d commit 98c90a1

21 files changed

+12200
-15
lines changed

llvm/docs/LangRef.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15760,7 +15760,8 @@ Syntax:
1576015760
"""""""
1576115761

1576215762
This is an overloaded intrinsic. You can use ``llvm.lrint`` on any
15763-
floating-point type. Not all targets support all types however.
15763+
floating-point type or vector of floating-point type. Not all targets
15764+
support all types however.
1576415765

1576515766
::
1576615767

@@ -15804,7 +15805,8 @@ Syntax:
1580415805
"""""""
1580515806

1580615807
This is an overloaded intrinsic. You can use ``llvm.llrint`` on any
15807-
floating-point type. Not all targets support all types however.
15808+
floating-point type or vector of floating-point type. Not all targets
15809+
support all types however.
1580815810

1580915811
::
1581015812

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,6 +1847,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
18471847
case Intrinsic::rint:
18481848
ISD = ISD::FRINT;
18491849
break;
1850+
case Intrinsic::lrint:
1851+
ISD = ISD::LRINT;
1852+
break;
1853+
case Intrinsic::llrint:
1854+
ISD = ISD::LLRINT;
1855+
break;
18501856
case Intrinsic::round:
18511857
ISD = ISD::FROUND;
18521858
break;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ namespace {
505505
SDValue visitUINT_TO_FP(SDNode *N);
506506
SDValue visitFP_TO_SINT(SDNode *N);
507507
SDValue visitFP_TO_UINT(SDNode *N);
508+
SDValue visitXRINT(SDNode *N);
508509
SDValue visitFP_ROUND(SDNode *N);
509510
SDValue visitFP_EXTEND(SDNode *N);
510511
SDValue visitFNEG(SDNode *N);
@@ -1911,6 +1912,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
19111912
}
19121913

19131914
SDValue DAGCombiner::visit(SDNode *N) {
1915+
// clang-format off
19141916
switch (N->getOpcode()) {
19151917
default: break;
19161918
case ISD::TokenFactor: return visitTokenFactor(N);
@@ -2011,6 +2013,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
20112013
case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
20122014
case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
20132015
case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
2016+
case ISD::LRINT:
2017+
case ISD::LLRINT: return visitXRINT(N);
20142018
case ISD::FP_ROUND: return visitFP_ROUND(N);
20152019
case ISD::FP_EXTEND: return visitFP_EXTEND(N);
20162020
case ISD::FNEG: return visitFNEG(N);
@@ -2065,6 +2069,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
20652069
#include "llvm/IR/VPIntrinsics.def"
20662070
return visitVPOp(N);
20672071
}
2072+
// clang-format on
20682073
return SDValue();
20692074
}
20702075

@@ -17480,6 +17485,21 @@ SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
1748017485
return FoldIntToFPToInt(N, DAG);
1748117486
}
1748217487

17488+
// Combine hook shared by ISD::LRINT and ISD::LLRINT nodes (both opcodes are
// dispatched here from DAGCombiner::visit). Returns a replacement value when
// one of the folds below applies, or an empty SDValue to signal "no change".
SDValue DAGCombiner::visitXRINT(SDNode *N) {
17489+
SDValue N0 = N->getOperand(0);
17490+
EVT VT = N->getValueType(0);
17491+
17492+
// fold (lrint|llrint undef) -> undef
// The result is an undef of the node's own result type VT, not of the
// operand's type.
17493+
if (N0.isUndef())
17494+
return DAG.getUNDEF(VT);
17495+
17496+
// fold (lrint|llrint c1fp) -> c1
// The node is re-requested from the DAG with the same opcode, type and
// constant operand; no folding is done in this function itself.
// NOTE(review): this presumably relies on DAG.getNode() constant-folding
// LRINT/LLRINT of an FP constant -- confirm getNode() actually does so,
// otherwise this just returns an equivalent node.
17497+
if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17498+
return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0);
17499+
17500+
// No fold applied.
return SDValue();
17501+
}
17502+
1748317503
SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
1748417504
SDValue N0 = N->getOperand(0);
1748517505
SDValue N1 = N->getOperand(1);

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2198,6 +2198,7 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
21982198
// to use the promoted float operand. Nodes that produce at least one
21992199
// promotion-requiring floating point result have their operands legalized as
22002200
// a part of PromoteFloatResult.
2201+
// clang-format off
22012202
switch (N->getOpcode()) {
22022203
default:
22032204
#ifndef NDEBUG
@@ -2209,7 +2210,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
22092210
case ISD::BITCAST: R = PromoteFloatOp_BITCAST(N, OpNo); break;
22102211
case ISD::FCOPYSIGN: R = PromoteFloatOp_FCOPYSIGN(N, OpNo); break;
22112212
case ISD::FP_TO_SINT:
2212-
case ISD::FP_TO_UINT: R = PromoteFloatOp_FP_TO_XINT(N, OpNo); break;
2213+
case ISD::FP_TO_UINT:
2214+
case ISD::LRINT:
2215+
case ISD::LLRINT: R = PromoteFloatOp_UnaryOp(N, OpNo); break;
22132216
case ISD::FP_TO_SINT_SAT:
22142217
case ISD::FP_TO_UINT_SAT:
22152218
R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
@@ -2218,6 +2221,7 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
22182221
case ISD::SETCC: R = PromoteFloatOp_SETCC(N, OpNo); break;
22192222
case ISD::STORE: R = PromoteFloatOp_STORE(N, OpNo); break;
22202223
}
2224+
// clang-format on
22212225

22222226
if (R.getNode())
22232227
ReplaceValueWith(SDValue(N, 0), R);
@@ -2251,7 +2255,7 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo) {
22512255
}
22522256

22532257
// Convert the promoted float value to the desired integer type
2254-
SDValue DAGTypeLegalizer::PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo) {
2258+
// Generic operand promotion for single-operand nodes: rebuilds N with the
// same opcode and the same result type, but with operand 0 replaced by its
// promoted-float counterpart. PromoteFloatOperand routes FP_TO_SINT,
// FP_TO_UINT, LRINT and LLRINT here.
// Note that OpNo is not consulted; operand 0 is always the one promoted.
SDValue DAGTypeLegalizer::PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo) {
22552259
SDValue Op = GetPromotedFloat(N->getOperand(0));
22562260
return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), Op);
22572261
}

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
711711
SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
712712
SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
713713
SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
714-
SDValue PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo);
714+
SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
715715
SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
716716
SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
717717
SDValue PromoteFloatOp_SELECT_CC(SDNode *N, unsigned OpNo);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
402402
case ISD::FCEIL:
403403
case ISD::FTRUNC:
404404
case ISD::FRINT:
405+
case ISD::LRINT:
406+
case ISD::LLRINT:
405407
case ISD::FNEARBYINT:
406408
case ISD::FROUND:
407409
case ISD::FROUNDEVEN:

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
101101
case ISD::FP_TO_SINT:
102102
case ISD::FP_TO_UINT:
103103
case ISD::FRINT:
104+
case ISD::LRINT:
105+
case ISD::LLRINT:
104106
case ISD::FROUND:
105107
case ISD::FROUNDEVEN:
106108
case ISD::FSIN:
@@ -681,6 +683,8 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
681683
case ISD::FP_TO_UINT:
682684
case ISD::SINT_TO_FP:
683685
case ISD::UINT_TO_FP:
686+
case ISD::LRINT:
687+
case ISD::LLRINT:
684688
Res = ScalarizeVecOp_UnaryOp(N);
685689
break;
686690
case ISD::STRICT_SINT_TO_FP:
@@ -1097,6 +1101,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
10971101
case ISD::VP_FP_TO_UINT:
10981102
case ISD::FRINT:
10991103
case ISD::VP_FRINT:
1104+
case ISD::LRINT:
1105+
case ISD::LLRINT:
11001106
case ISD::FROUND:
11011107
case ISD::VP_FROUND:
11021108
case ISD::FROUNDEVEN:
@@ -2974,6 +2980,8 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
29742980
case ISD::ZERO_EXTEND:
29752981
case ISD::ANY_EXTEND:
29762982
case ISD::FTRUNC:
2983+
case ISD::LRINT:
2984+
case ISD::LLRINT:
29772985
Res = SplitVecOp_UnaryOp(N);
29782986
break;
29792987
case ISD::FLDEXP:
@@ -4209,6 +4217,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
42094217
case ISD::FLOG2:
42104218
case ISD::FNEARBYINT:
42114219
case ISD::FRINT:
4220+
case ISD::LRINT:
4221+
case ISD::LLRINT:
42124222
case ISD::FROUND:
42134223
case ISD::FROUNDEVEN:
42144224
case ISD::FSIN:
@@ -5958,7 +5968,11 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
59585968
case ISD::STRICT_FSETCCS: Res = WidenVecOp_STRICT_FSETCC(N); break;
59595969
case ISD::VSELECT: Res = WidenVecOp_VSELECT(N); break;
59605970
case ISD::FLDEXP:
5961-
case ISD::FCOPYSIGN: Res = WidenVecOp_UnrollVectorOp(N); break;
5971+
case ISD::FCOPYSIGN:
5972+
case ISD::LRINT:
5973+
case ISD::LLRINT:
5974+
Res = WidenVecOp_UnrollVectorOp(N);
5975+
break;
59625976
case ISD::IS_FPCLASS: Res = WidenVecOp_IS_FPCLASS(N); break;
59635977

59645978
case ISD::ANY_EXTEND:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5135,6 +5135,8 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const
51355135
case ISD::FROUND:
51365136
case ISD::FROUNDEVEN:
51375137
case ISD::FRINT:
5138+
case ISD::LRINT:
5139+
case ISD::LLRINT:
51385140
case ISD::FNEARBYINT:
51395141
case ISD::FLDEXP: {
51405142
if (SNaN)

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -873,13 +873,13 @@ void TargetLoweringBase::initActions() {
873873

874874
// These operations default to expand for vector types.
875875
if (VT.isVector())
876-
setOperationAction({ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG,
877-
ISD::ANY_EXTEND_VECTOR_INREG,
878-
ISD::SIGN_EXTEND_VECTOR_INREG,
879-
ISD::ZERO_EXTEND_VECTOR_INREG, ISD::SPLAT_VECTOR},
880-
VT, Expand);
876+
setOperationAction(
877+
{ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG, ISD::ANY_EXTEND_VECTOR_INREG,
878+
ISD::SIGN_EXTEND_VECTOR_INREG, ISD::ZERO_EXTEND_VECTOR_INREG,
879+
ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT},
880+
VT, Expand);
881881

882-
// Constrained floating-point operations default to expand.
882+
// Constrained floating-point operations default to expand.
883883
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
884884
setOperationAction(ISD::STRICT_##DAGN, VT, Expand);
885885
#include "llvm/IR/ConstrainedOps.def"

llvm/lib/IR/Verifier.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5669,10 +5669,28 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
56695669
}
56705670
break;
56715671
}
5672-
case Intrinsic::lround:
5673-
case Intrinsic::llround:
56745672
case Intrinsic::lrint:
56755673
case Intrinsic::llrint: {
5674+
Type *ValTy = Call.getArgOperand(0)->getType();
5675+
Type *ResultTy = Call.getType();
5676+
Check(
5677+
ValTy->isFPOrFPVectorTy() && ResultTy->isIntOrIntVectorTy(),
5678+
"llvm.lrint, llvm.llrint: argument must be floating-point or vector "
5679+
"of floating-points, and result must be integer or vector of integers",
5680+
&Call);
5681+
Check(ValTy->isVectorTy() == ResultTy->isVectorTy(),
5682+
"llvm.lrint, llvm.llrint: argument and result disagree on vector use",
5683+
&Call);
5684+
if (ValTy->isVectorTy()) {
5685+
Check(cast<VectorType>(ValTy)->getElementCount() ==
5686+
cast<VectorType>(ResultTy)->getElementCount(),
5687+
"llvm.lrint, llvm.llrint: argument must be same length as result",
5688+
&Call);
5689+
}
5690+
break;
5691+
}
5692+
case Intrinsic::lround:
5693+
case Intrinsic::llround: {
56765694
Type *ValTy = Call.getArgOperand(0)->getType();
56775695
Type *ResultTy = Call.getType();
56785696
Check(!ValTy->isVectorTy() && !ResultTy->isVectorTy(),

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
731731
VT, Custom);
732732
setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
733733
Custom);
734-
734+
setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
735735
setOperationAction(
736736
{ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT, Legal);
737737

@@ -2950,6 +2950,31 @@ lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
29502950
DAG.getTargetConstant(FRM, DL, Subtarget.getXLenVT()));
29512951
}
29522952

2953+
// Expand vector LRINT and LLRINT by converting to the integer domain.
//
// Op is the LRINT/LLRINT node being lowered (LowerOperation dispatches both
// opcodes here); its result type must be a vector. The operand is converted
// with a single RISCVISD::VFCVT_X_F_VL node using the default mask and VL.
// Fixed-length vectors are moved into a scalable container type before the
// conversion and extracted back to the fixed-length type afterwards.
2954+
static SDValue lowerVectorXRINT(SDValue Op, SelectionDAG &DAG,
2955+
const RISCVSubtarget &Subtarget) {
2956+
// VT is the type of the node's result, i.e. the integer vector type.
MVT VT = Op.getSimpleValueType();
2957+
assert(VT.isVector() && "Unexpected type");
2958+
2959+
SDLoc DL(Op);
2960+
SDValue Src = Op.getOperand(0);
2961+
// Scalable vectors are used as-is; only fixed-length vectors need a container.
MVT ContainerVT = VT;
2962+
2963+
if (VT.isFixedLengthVector()) {
2964+
// NOTE(review): ContainerVT is derived from the result type VT, yet it is
// also the type Src (the floating-point operand) is converted to on the next
// line. Confirm this is valid for the source's element type, or whether a
// separate container type computed from Src's type is needed.
ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
2965+
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
2966+
}
2967+
2968+
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
2969+
// The conversion result is produced directly in ContainerVT.
SDValue Truncated =
2970+
DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, ContainerVT, Src, Mask, VL);
2971+
2972+
// Scalable result: nothing to unwrap.
if (!VT.isFixedLengthVector())
2973+
return Truncated;
2974+
2975+
// Fixed-length result: extract it back out of the scalable container.
return convertFromScalableVector(VT, Truncated, DAG, Subtarget);
2976+
}
2977+
29532978
static SDValue
29542979
getVSlidedown(SelectionDAG &DAG, const RISCVSubtarget &Subtarget,
29552980
const SDLoc &DL, EVT VT, SDValue Merge, SDValue Op,
@@ -5978,6 +6003,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
59786003
case ISD::FROUND:
59796004
case ISD::FROUNDEVEN:
59806005
return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
6006+
case ISD::LRINT:
6007+
case ISD::LLRINT:
6008+
return lowerVectorXRINT(Op, DAG, Subtarget);
59816009
case ISD::VECREDUCE_ADD:
59826010
case ISD::VECREDUCE_UMAX:
59836011
case ISD::VECREDUCE_SMAX:

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,31 @@ static const CostTblEntry VectorIntrinsicCostTable[]{
668668
{Intrinsic::rint, MVT::nxv2f64, 7},
669669
{Intrinsic::rint, MVT::nxv4f64, 7},
670670
{Intrinsic::rint, MVT::nxv8f64, 7},
671+
{Intrinsic::lrint, MVT::v2i32, 1},
672+
{Intrinsic::lrint, MVT::v4i32, 1},
673+
{Intrinsic::lrint, MVT::v8i32, 1},
674+
{Intrinsic::lrint, MVT::v16i32, 1},
675+
{Intrinsic::lrint, MVT::nxv1i32, 1},
676+
{Intrinsic::lrint, MVT::nxv2i32, 1},
677+
{Intrinsic::lrint, MVT::nxv4i32, 1},
678+
{Intrinsic::lrint, MVT::nxv8i32, 1},
679+
{Intrinsic::lrint, MVT::nxv16i32, 1},
680+
{Intrinsic::lrint, MVT::v2i64, 1},
681+
{Intrinsic::lrint, MVT::v4i64, 1},
682+
{Intrinsic::lrint, MVT::v8i64, 1},
683+
{Intrinsic::lrint, MVT::v16i64, 1},
684+
{Intrinsic::lrint, MVT::nxv1i64, 1},
685+
{Intrinsic::lrint, MVT::nxv2i64, 1},
686+
{Intrinsic::lrint, MVT::nxv4i64, 1},
687+
{Intrinsic::lrint, MVT::nxv8i64, 1},
688+
{Intrinsic::llrint, MVT::v2i64, 1},
689+
{Intrinsic::llrint, MVT::v4i64, 1},
690+
{Intrinsic::llrint, MVT::v8i64, 1},
691+
{Intrinsic::llrint, MVT::v16i64, 1},
692+
{Intrinsic::llrint, MVT::nxv1i64, 1},
693+
{Intrinsic::llrint, MVT::nxv2i64, 1},
694+
{Intrinsic::llrint, MVT::nxv4i64, 1},
695+
{Intrinsic::llrint, MVT::nxv8i64, 1},
671696
{Intrinsic::nearbyint, MVT::v2f32, 9},
672697
{Intrinsic::nearbyint, MVT::v4f32, 9},
673698
{Intrinsic::nearbyint, MVT::v8f32, 9},
@@ -1051,6 +1076,8 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
10511076
case Intrinsic::floor:
10521077
case Intrinsic::trunc:
10531078
case Intrinsic::rint:
1079+
case Intrinsic::lrint:
1080+
case Intrinsic::llrint:
10541081
case Intrinsic::round:
10551082
case Intrinsic::roundeven: {
10561083
// These all use the same code.

0 commit comments

Comments
 (0)