Skip to content

Commit 915139f

Browse files
committed
[X86][BF16] Lower FP_EXTEND for vector types under AVX512BF16
Fixes #64460 Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D158950
1 parent 2602157 commit 915139f

File tree

3 files changed

+510
-161
lines changed

3 files changed

+510
-161
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,17 +1613,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
16131613
setOperationAction(ISD::FP_ROUND, VT, Custom);
16141614
setOperationAction(ISD::STRICT_FP_ROUND, VT, Custom);
16151615
}
1616-
for (MVT VT : { MVT::f32, MVT::v2f32, MVT::v4f32 }) {
1616+
for (MVT VT : { MVT::f32, MVT::v2f32, MVT::v4f32, MVT::v8f32 }) {
16171617
setOperationAction(ISD::FP_EXTEND, VT, Custom);
16181618
setOperationAction(ISD::STRICT_FP_EXTEND, VT, Custom);
16191619
}
16201620
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV}) {
16211621
setOperationPromotedToType(Opc, MVT::v8f16, MVT::v8f32);
16221622
setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
16231623
}
1624-
1625-
setOperationAction(ISD::FP_EXTEND, MVT::v8f32, Legal);
1626-
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v8f32, Legal);
16271624
}
16281625

16291626
// This block controls legalization of the mask vector sizes that are
@@ -1940,8 +1937,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
19401937
setF16Action(MVT::v32f16, Expand);
19411938
setOperationAction(ISD::FP_ROUND, MVT::v16f16, Custom);
19421939
setOperationAction(ISD::STRICT_FP_ROUND, MVT::v16f16, Custom);
1943-
setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Legal);
1944-
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Legal);
1940+
setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Custom);
1941+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Custom);
19451942
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV}) {
19461943
setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
19471944
setOperationPromotedToType(Opc, MVT::v32f16, MVT::v32f32);
@@ -2162,9 +2159,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
21622159
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::v32i16, Legal);
21632160
setOperationAction(ISD::FP_ROUND, MVT::v16f16, Legal);
21642161
setOperationAction(ISD::STRICT_FP_ROUND, MVT::v16f16, Legal);
2165-
setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Legal);
2162+
setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Custom);
21662163
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Legal);
2167-
setOperationAction(ISD::FP_EXTEND, MVT::v8f64, Legal);
2164+
setOperationAction(ISD::FP_EXTEND, MVT::v8f64, Custom);
21682165
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v8f64, Legal);
21692166
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v32f16, Custom);
21702167

@@ -2214,9 +2211,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
22142211
setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::v8i16, Custom);
22152212
setOperationAction(ISD::FP_ROUND, MVT::v8f16, Legal);
22162213
setOperationAction(ISD::STRICT_FP_ROUND, MVT::v8f16, Legal);
2217-
setOperationAction(ISD::FP_EXTEND, MVT::v8f32, Legal);
2214+
setOperationAction(ISD::FP_EXTEND, MVT::v8f32, Custom);
22182215
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v8f32, Legal);
2219-
setOperationAction(ISD::FP_EXTEND, MVT::v4f64, Legal);
2216+
setOperationAction(ISD::FP_EXTEND, MVT::v4f64, Custom);
22202217
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f64, Legal);
22212218

22222219
// INSERT_VECTOR_ELT v8f16 extended to VECTOR_SHUFFLE
@@ -11914,13 +11911,9 @@ static bool isShuffleFoldableLoad(SDValue V) {
1191411911
}
1191511912

1191611913
template<typename T>
11917-
static bool isSoftFP16(T VT, const X86Subtarget &Subtarget) {
11918-
return VT.getScalarType() == MVT::f16 && !Subtarget.hasFP16();
11919-
}
11920-
11921-
template<typename T>
11922-
bool X86TargetLowering::isSoftFP16(T VT) const {
11923-
return ::isSoftFP16(VT, Subtarget);
11914+
static bool isSoftF16(T VT, const X86Subtarget &Subtarget) {
11915+
T EltVT = VT.getScalarType();
11916+
return EltVT == MVT::bf16 || (EltVT == MVT::f16 && !Subtarget.hasFP16());
1192411917
}
1192511918

1192611919
/// Try to lower insertion of a single element into a zero vector.
@@ -11936,7 +11929,7 @@ static SDValue lowerShuffleAsElementInsertion(
1193611929
unsigned NumElts = VT.getVectorNumElements();
1193711930
unsigned EltBits = VT.getScalarSizeInBits();
1193811931

11939-
if (isSoftFP16(EltVT, Subtarget))
11932+
if (isSoftF16(EltVT, Subtarget))
1194011933
return SDValue();
1194111934

1194211935
int V2Index =
@@ -17491,7 +17484,7 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const {
1749117484

1749217485
SDLoc dl(Op);
1749317486
MVT VT = Op.getSimpleValueType();
17494-
if (isSoftFP16(VT)) {
17487+
if (isSoftF16(VT, Subtarget)) {
1749517488
MVT NVT = VT.changeVectorElementTypeToInteger();
1749617489
return DAG.getBitcast(VT, DAG.getNode(ISD::VSELECT, dl, NVT, Cond,
1749717490
DAG.getBitcast(NVT, LHS),
@@ -19019,7 +19012,7 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op,
1901919012
MVT VT = Op.getSimpleValueType();
1902019013
SDLoc dl(Op);
1902119014

19022-
if (isSoftFP16(VT))
19015+
if (isSoftF16(VT, Subtarget))
1902319016
return promoteXINT_TO_FP(Op, DAG);
1902419017
else if (isLegalConversion(SrcVT, true, Subtarget))
1902519018
return Op;
@@ -19524,7 +19517,7 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op,
1952419517
if (DstVT == MVT::f128)
1952519518
return SDValue();
1952619519

19527-
if (isSoftFP16(DstVT))
19520+
if (isSoftF16(DstVT, Subtarget))
1952819521
return promoteXINT_TO_FP(Op, DAG);
1952919522
else if (isLegalConversion(SrcVT, false, Subtarget))
1953019523
return Op;
@@ -20543,7 +20536,7 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
2054320536
SDLoc dl(Op);
2054420537

2054520538
SDValue Res;
20546-
if (isSoftFP16(SrcVT)) {
20539+
if (isSoftF16(SrcVT, Subtarget)) {
2054720540
MVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
2054820541
if (IsStrict)
2054920542
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
@@ -20972,7 +20965,7 @@ X86TargetLowering::LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const {
2097220965

2097320966
// This code is only for floats and doubles. Fall back to generic code for
2097420967
// anything else.
20975-
if (!isScalarFPTypeInSSEReg(SrcVT) || isSoftFP16(SrcVT))
20968+
if (!isScalarFPTypeInSSEReg(SrcVT) || isSoftF16(SrcVT, Subtarget))
2097620969
return SDValue();
2097720970

2097820971
EVT SatVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
@@ -21117,6 +21110,10 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
2111721110
!Subtarget.getTargetTriple().isOSDarwin()))
2111821111
return SDValue();
2111921112

21113+
if ((SVT == MVT::v8f16 && Subtarget.hasF16C()) ||
21114+
(SVT == MVT::v16f16 && Subtarget.useAVX512Regs()))
21115+
return Op;
21116+
2112021117
if (SVT == MVT::f16) {
2112121118
if (Subtarget.hasFP16())
2112221119
return Op;
@@ -21189,7 +21186,25 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
2118921186
if (!SVT.isVector())
2119021187
return Op;
2119121188

21189+
if (SVT.getVectorElementType() == MVT::bf16) {
21190+
// FIXME: Do we need to support strict FP?
21191+
assert(!IsStrict && "Strict FP doesn't support BF16");
21192+
if (VT.getVectorElementType() == MVT::f64) {
21193+
MVT TmpVT = VT.changeVectorElementType(MVT::f32);
21194+
return DAG.getNode(ISD::FP_EXTEND, DL, VT,
21195+
DAG.getNode(ISD::FP_EXTEND, DL, TmpVT, In));
21196+
}
21197+
assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
21198+
MVT NVT = SVT.changeVectorElementType(MVT::i32);
21199+
In = DAG.getBitcast(SVT.changeTypeToInteger(), In);
21200+
In = DAG.getNode(ISD::ZERO_EXTEND, DL, NVT, In);
21201+
In = DAG.getNode(ISD::SHL, DL, NVT, In, DAG.getConstant(8, DL, NVT));
21202+
return DAG.getBitcast(VT, In);
21203+
}
21204+
2119221205
if (SVT.getVectorElementType() == MVT::f16) {
21206+
if (Subtarget.hasFP16() && isTypeLegal(SVT))
21207+
return Op;
2119321208
assert(Subtarget.hasF16C() && "Unexpected features!");
2119421209
if (SVT == MVT::v2f16)
2119521210
In = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f16, In,
@@ -22910,7 +22925,7 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2291022925
if (isFP) {
2291122926
MVT EltVT = Op0.getSimpleValueType().getVectorElementType();
2291222927
assert(EltVT == MVT::f16 || EltVT == MVT::f32 || EltVT == MVT::f64);
22913-
if (isSoftFP16(EltVT, Subtarget))
22928+
if (isSoftF16(EltVT, Subtarget))
2291422929
return SDValue();
2291522930

2291622931
bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
@@ -23475,7 +23490,7 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
2347523490
ISD::CondCode CC =
2347623491
cast<CondCodeSDNode>(Op.getOperand(IsStrict ? 3 : 2))->get();
2347723492

23478-
if (isSoftFP16(Op0.getValueType()))
23493+
if (isSoftF16(Op0.getValueType(), Subtarget))
2347923494
return SDValue();
2348023495

2348123496
// Handle f128 first, since one possible outcome is a normal integer
@@ -23668,7 +23683,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
2366823683
MVT VT = Op1.getSimpleValueType();
2366923684
SDValue CC;
2367023685

23671-
if (isSoftFP16(VT)) {
23686+
if (isSoftF16(VT, Subtarget)) {
2367223687
MVT NVT = VT.changeTypeToInteger();
2367323688
return DAG.getBitcast(VT, DAG.getNode(ISD::SELECT, DL, NVT, Cond,
2367423689
DAG.getBitcast(NVT, Op1),
@@ -23740,7 +23755,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
2374023755
}
2374123756

2374223757
if (Cond.getOpcode() == ISD::SETCC &&
23743-
!isSoftFP16(Cond.getOperand(0).getSimpleValueType())) {
23758+
!isSoftF16(Cond.getOperand(0).getSimpleValueType(), Subtarget)) {
2374423759
if (SDValue NewCond = LowerSETCC(Cond, DAG)) {
2374523760
Cond = NewCond;
2374623761
// If the condition was updated, it's possible that the operands of the
@@ -24430,7 +24445,7 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
2443024445
// Bail out when we don't have native compare instructions.
2443124446
if (Cond.getOpcode() == ISD::SETCC &&
2443224447
Cond.getOperand(0).getValueType() != MVT::f128 &&
24433-
!isSoftFP16(Cond.getOperand(0).getValueType())) {
24448+
!isSoftF16(Cond.getOperand(0).getValueType(), Subtarget)) {
2443424449
SDValue LHS = Cond.getOperand(0);
2443524450
SDValue RHS = Cond.getOperand(1);
2443624451
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
@@ -32231,7 +32246,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
3223132246
EVT SrcVT = Src.getValueType();
3223232247

3223332248
SDValue Res;
32234-
if (isSoftFP16(SrcVT)) {
32249+
if (isSoftF16(SrcVT, Subtarget)) {
3223532250
EVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
3223632251
if (IsStrict) {
3223732252
Res =
@@ -44636,7 +44651,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4463644651
// ignored in unsafe-math mode).
4463744652
// We also try to create v2f32 min/max nodes, which we later widen to v4f32.
4463844653
if (Cond.getOpcode() == ISD::SETCC && VT.isFloatingPoint() &&
44639-
VT != MVT::f80 && VT != MVT::f128 && !isSoftFP16(VT, Subtarget) &&
44654+
VT != MVT::f80 && VT != MVT::f128 && !isSoftF16(VT, Subtarget) &&
4464044655
(TLI.isTypeLegal(VT) || VT == MVT::v2f32) &&
4464144656
(Subtarget.hasSSE2() ||
4464244657
(Subtarget.hasSSE1() && VT.getScalarType() == MVT::f32))) {
@@ -44953,7 +44968,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4495344968
}
4495444969

4495544970
// Early exit check
44956-
if (!TLI.isTypeLegal(VT) || isSoftFP16(VT, Subtarget))
44971+
if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
4495744972
return SDValue();
4495844973

4495944974
if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DCI, Subtarget))
@@ -51712,7 +51727,7 @@ static SDValue combineFMinFMax(SDNode *N, SelectionDAG &DAG) {
5171251727
static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG,
5171351728
const X86Subtarget &Subtarget) {
5171451729
EVT VT = N->getValueType(0);
51715-
if (Subtarget.useSoftFloat() || isSoftFP16(VT, Subtarget))
51730+
if (Subtarget.useSoftFloat() || isSoftF16(VT, Subtarget))
5171651731
return SDValue();
5171751732

5171851733
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,8 +1755,6 @@ namespace llvm {
17551755

17561756
bool needsCmpXchgNb(Type *MemType) const;
17571757

1758-
template<typename T> bool isSoftFP16(T VT) const;
1759-
17601758
void SetupEntryBlockForSjLj(MachineInstr &MI, MachineBasicBlock *MBB,
17611759
MachineBasicBlock *DispatchBB, int FI) const;
17621760

0 commit comments

Comments
 (0)