@@ -1613,17 +1613,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
1613
1613
setOperationAction(ISD::FP_ROUND, VT, Custom);
1614
1614
setOperationAction(ISD::STRICT_FP_ROUND, VT, Custom);
1615
1615
}
1616
- for (MVT VT : { MVT::f32, MVT::v2f32, MVT::v4f32 }) {
1616
+ for (MVT VT : { MVT::f32, MVT::v2f32, MVT::v4f32, MVT::v8f32 }) {
1617
1617
setOperationAction(ISD::FP_EXTEND, VT, Custom);
1618
1618
setOperationAction(ISD::STRICT_FP_EXTEND, VT, Custom);
1619
1619
}
1620
1620
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV}) {
1621
1621
setOperationPromotedToType(Opc, MVT::v8f16, MVT::v8f32);
1622
1622
setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
1623
1623
}
1624
-
1625
- setOperationAction(ISD::FP_EXTEND, MVT::v8f32, Legal);
1626
- setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v8f32, Legal);
1627
1624
}
1628
1625
1629
1626
// This block controls legalization of the mask vector sizes that are
@@ -1940,8 +1937,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
1940
1937
setF16Action(MVT::v32f16, Expand);
1941
1938
setOperationAction(ISD::FP_ROUND, MVT::v16f16, Custom);
1942
1939
setOperationAction(ISD::STRICT_FP_ROUND, MVT::v16f16, Custom);
1943
- setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Legal );
1944
- setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Legal );
1940
+ setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Custom );
1941
+ setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Custom );
1945
1942
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV}) {
1946
1943
setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
1947
1944
setOperationPromotedToType(Opc, MVT::v32f16, MVT::v32f32);
@@ -2162,9 +2159,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
2162
2159
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::v32i16, Legal);
2163
2160
setOperationAction(ISD::FP_ROUND, MVT::v16f16, Legal);
2164
2161
setOperationAction(ISD::STRICT_FP_ROUND, MVT::v16f16, Legal);
2165
- setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Legal );
2162
+ setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Custom );
2166
2163
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Legal);
2167
- setOperationAction(ISD::FP_EXTEND, MVT::v8f64, Legal );
2164
+ setOperationAction(ISD::FP_EXTEND, MVT::v8f64, Custom );
2168
2165
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v8f64, Legal);
2169
2166
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v32f16, Custom);
2170
2167
@@ -2214,9 +2211,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
2214
2211
setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::v8i16, Custom);
2215
2212
setOperationAction(ISD::FP_ROUND, MVT::v8f16, Legal);
2216
2213
setOperationAction(ISD::STRICT_FP_ROUND, MVT::v8f16, Legal);
2217
- setOperationAction(ISD::FP_EXTEND, MVT::v8f32, Legal );
2214
+ setOperationAction(ISD::FP_EXTEND, MVT::v8f32, Custom );
2218
2215
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v8f32, Legal);
2219
- setOperationAction(ISD::FP_EXTEND, MVT::v4f64, Legal );
2216
+ setOperationAction(ISD::FP_EXTEND, MVT::v4f64, Custom );
2220
2217
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f64, Legal);
2221
2218
2222
2219
// INSERT_VECTOR_ELT v8f16 extended to VECTOR_SHUFFLE
@@ -11914,13 +11911,9 @@ static bool isShuffleFoldableLoad(SDValue V) {
11914
11911
}
11915
11912
11916
11913
template<typename T>
11917
- static bool isSoftFP16(T VT, const X86Subtarget &Subtarget) {
11918
- return VT.getScalarType() == MVT::f16 && !Subtarget.hasFP16();
11919
- }
11920
-
11921
- template<typename T>
11922
- bool X86TargetLowering::isSoftFP16(T VT) const {
11923
- return ::isSoftFP16(VT, Subtarget);
11914
+ static bool isSoftF16(T VT, const X86Subtarget &Subtarget) {
11915
+ T EltVT = VT.getScalarType();
11916
+ return EltVT == MVT::bf16 || (EltVT == MVT::f16 && !Subtarget.hasFP16());
11924
11917
}
11925
11918
11926
11919
/// Try to lower insertion of a single element into a zero vector.
@@ -11936,7 +11929,7 @@ static SDValue lowerShuffleAsElementInsertion(
11936
11929
unsigned NumElts = VT.getVectorNumElements();
11937
11930
unsigned EltBits = VT.getScalarSizeInBits();
11938
11931
11939
- if (isSoftFP16 (EltVT, Subtarget))
11932
+ if (isSoftF16 (EltVT, Subtarget))
11940
11933
return SDValue();
11941
11934
11942
11935
int V2Index =
@@ -17491,7 +17484,7 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const {
17491
17484
17492
17485
SDLoc dl(Op);
17493
17486
MVT VT = Op.getSimpleValueType();
17494
- if (isSoftFP16 (VT)) {
17487
+ if (isSoftF16 (VT, Subtarget )) {
17495
17488
MVT NVT = VT.changeVectorElementTypeToInteger();
17496
17489
return DAG.getBitcast(VT, DAG.getNode(ISD::VSELECT, dl, NVT, Cond,
17497
17490
DAG.getBitcast(NVT, LHS),
@@ -19019,7 +19012,7 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op,
19019
19012
MVT VT = Op.getSimpleValueType();
19020
19013
SDLoc dl(Op);
19021
19014
19022
- if (isSoftFP16 (VT))
19015
+ if (isSoftF16 (VT, Subtarget ))
19023
19016
return promoteXINT_TO_FP(Op, DAG);
19024
19017
else if (isLegalConversion(SrcVT, true, Subtarget))
19025
19018
return Op;
@@ -19524,7 +19517,7 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op,
19524
19517
if (DstVT == MVT::f128)
19525
19518
return SDValue();
19526
19519
19527
- if (isSoftFP16 (DstVT))
19520
+ if (isSoftF16 (DstVT, Subtarget ))
19528
19521
return promoteXINT_TO_FP(Op, DAG);
19529
19522
else if (isLegalConversion(SrcVT, false, Subtarget))
19530
19523
return Op;
@@ -20543,7 +20536,7 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
20543
20536
SDLoc dl(Op);
20544
20537
20545
20538
SDValue Res;
20546
- if (isSoftFP16 (SrcVT)) {
20539
+ if (isSoftF16 (SrcVT, Subtarget )) {
20547
20540
MVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
20548
20541
if (IsStrict)
20549
20542
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
@@ -20972,7 +20965,7 @@ X86TargetLowering::LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const {
20972
20965
20973
20966
// This code is only for floats and doubles. Fall back to generic code for
20974
20967
// anything else.
20975
- if (!isScalarFPTypeInSSEReg(SrcVT) || isSoftFP16 (SrcVT))
20968
+ if (!isScalarFPTypeInSSEReg(SrcVT) || isSoftF16 (SrcVT, Subtarget ))
20976
20969
return SDValue();
20977
20970
20978
20971
EVT SatVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
@@ -21117,6 +21110,10 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
21117
21110
!Subtarget.getTargetTriple().isOSDarwin()))
21118
21111
return SDValue();
21119
21112
21113
+ if ((SVT == MVT::v8f16 && Subtarget.hasF16C()) ||
21114
+ (SVT == MVT::v16f16 && Subtarget.useAVX512Regs()))
21115
+ return Op;
21116
+
21120
21117
if (SVT == MVT::f16) {
21121
21118
if (Subtarget.hasFP16())
21122
21119
return Op;
@@ -21189,7 +21186,25 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
21189
21186
if (!SVT.isVector())
21190
21187
return Op;
21191
21188
21189
+ if (SVT.getVectorElementType() == MVT::bf16) {
21190
+ // FIXME: Do we need to support strict FP?
21191
+ assert(!IsStrict && "Strict FP doesn't support BF16");
21192
+ if (VT.getVectorElementType() == MVT::f64) {
21193
+ MVT TmpVT = VT.changeVectorElementType(MVT::f32);
21194
+ return DAG.getNode(ISD::FP_EXTEND, DL, VT,
21195
+ DAG.getNode(ISD::FP_EXTEND, DL, TmpVT, In));
21196
+ }
21197
+ assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
21198
+ MVT NVT = SVT.changeVectorElementType(MVT::i32);
21199
+ In = DAG.getBitcast(SVT.changeTypeToInteger(), In);
21200
+ In = DAG.getNode(ISD::ZERO_EXTEND, DL, NVT, In);
21201
+ In = DAG.getNode(ISD::SHL, DL, NVT, In, DAG.getConstant(16, DL, NVT)); // bf16 is the upper 16 bits of an f32, so move it into the high half.
21202
+ return DAG.getBitcast(VT, In);
21203
+ }
21204
+
21192
21205
if (SVT.getVectorElementType() == MVT::f16) {
21206
+ if (Subtarget.hasFP16() && isTypeLegal(SVT))
21207
+ return Op;
21193
21208
assert(Subtarget.hasF16C() && "Unexpected features!");
21194
21209
if (SVT == MVT::v2f16)
21195
21210
In = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f16, In,
@@ -22910,7 +22925,7 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
22910
22925
if (isFP) {
22911
22926
MVT EltVT = Op0.getSimpleValueType().getVectorElementType();
22912
22927
assert(EltVT == MVT::f16 || EltVT == MVT::f32 || EltVT == MVT::f64);
22913
- if (isSoftFP16 (EltVT, Subtarget))
22928
+ if (isSoftF16 (EltVT, Subtarget))
22914
22929
return SDValue();
22915
22930
22916
22931
bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
@@ -23475,7 +23490,7 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
23475
23490
ISD::CondCode CC =
23476
23491
cast<CondCodeSDNode>(Op.getOperand(IsStrict ? 3 : 2))->get();
23477
23492
23478
- if (isSoftFP16 (Op0.getValueType()))
23493
+ if (isSoftF16 (Op0.getValueType(), Subtarget ))
23479
23494
return SDValue();
23480
23495
23481
23496
// Handle f128 first, since one possible outcome is a normal integer
@@ -23668,7 +23683,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
23668
23683
MVT VT = Op1.getSimpleValueType();
23669
23684
SDValue CC;
23670
23685
23671
- if (isSoftFP16 (VT)) {
23686
+ if (isSoftF16 (VT, Subtarget )) {
23672
23687
MVT NVT = VT.changeTypeToInteger();
23673
23688
return DAG.getBitcast(VT, DAG.getNode(ISD::SELECT, DL, NVT, Cond,
23674
23689
DAG.getBitcast(NVT, Op1),
@@ -23740,7 +23755,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
23740
23755
}
23741
23756
23742
23757
if (Cond.getOpcode() == ISD::SETCC &&
23743
- !isSoftFP16 (Cond.getOperand(0).getSimpleValueType())) {
23758
+ !isSoftF16 (Cond.getOperand(0).getSimpleValueType(), Subtarget )) {
23744
23759
if (SDValue NewCond = LowerSETCC(Cond, DAG)) {
23745
23760
Cond = NewCond;
23746
23761
// If the condition was updated, it's possible that the operands of the
@@ -24430,7 +24445,7 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
24430
24445
// Bail out when we don't have native compare instructions.
24431
24446
if (Cond.getOpcode() == ISD::SETCC &&
24432
24447
Cond.getOperand(0).getValueType() != MVT::f128 &&
24433
- !isSoftFP16 (Cond.getOperand(0).getValueType())) {
24448
+ !isSoftF16 (Cond.getOperand(0).getValueType(), Subtarget )) {
24434
24449
SDValue LHS = Cond.getOperand(0);
24435
24450
SDValue RHS = Cond.getOperand(1);
24436
24451
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
@@ -32231,7 +32246,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
32231
32246
EVT SrcVT = Src.getValueType();
32232
32247
32233
32248
SDValue Res;
32234
- if (isSoftFP16 (SrcVT)) {
32249
+ if (isSoftF16 (SrcVT, Subtarget )) {
32235
32250
EVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
32236
32251
if (IsStrict) {
32237
32252
Res =
@@ -44636,7 +44651,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
44636
44651
// ignored in unsafe-math mode).
44637
44652
// We also try to create v2f32 min/max nodes, which we later widen to v4f32.
44638
44653
if (Cond.getOpcode() == ISD::SETCC && VT.isFloatingPoint() &&
44639
- VT != MVT::f80 && VT != MVT::f128 && !isSoftFP16 (VT, Subtarget) &&
44654
+ VT != MVT::f80 && VT != MVT::f128 && !isSoftF16 (VT, Subtarget) &&
44640
44655
(TLI.isTypeLegal(VT) || VT == MVT::v2f32) &&
44641
44656
(Subtarget.hasSSE2() ||
44642
44657
(Subtarget.hasSSE1() && VT.getScalarType() == MVT::f32))) {
@@ -44953,7 +44968,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
44953
44968
}
44954
44969
44955
44970
// Early exit check
44956
- if (!TLI.isTypeLegal(VT) || isSoftFP16 (VT, Subtarget))
44971
+ if (!TLI.isTypeLegal(VT) || isSoftF16 (VT, Subtarget))
44957
44972
return SDValue();
44958
44973
44959
44974
if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DCI, Subtarget))
@@ -51712,7 +51727,7 @@ static SDValue combineFMinFMax(SDNode *N, SelectionDAG &DAG) {
51712
51727
static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG,
51713
51728
const X86Subtarget &Subtarget) {
51714
51729
EVT VT = N->getValueType(0);
51715
- if (Subtarget.useSoftFloat() || isSoftFP16 (VT, Subtarget))
51730
+ if (Subtarget.useSoftFloat() || isSoftF16 (VT, Subtarget))
51716
51731
return SDValue();
51717
51732
51718
51733
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
0 commit comments