Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3507,6 +3507,7 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FMAD(SDNode *N) {
SDValue Op0 = GetSoftPromotedHalf(N->getOperand(0));
SDValue Op1 = GetSoftPromotedHalf(N->getOperand(1));
SDValue Op2 = GetSoftPromotedHalf(N->getOperand(2));
SDNodeFlags Flags = N->getFlags();
SDLoc dl(N);

// Promote to the larger FP type.
Expand All @@ -3515,9 +3516,28 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FMAD(SDNode *N) {
Op1 = DAG.getNode(PromotionOpcode, dl, NVT, Op1);
Op2 = DAG.getNode(PromotionOpcode, dl, NVT, Op2);

SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2);
SDValue Res;
if (OVT == MVT::f16) {
// If f16 fma is not natively supported, the value must be promoted to an
// f64 (and not to f32!) to prevent double rounding issues.
SDValue A64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op0, Flags);
SDValue B64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op1, Flags);
SDValue C64 = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f64, Op2, Flags);

// Prefer a wide FMA node if available; otherwise expand to mul+add.
SDValue WideRes;
if (TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), MVT::f64)) {
WideRes = DAG.getNode(ISD::FMA, dl, MVT::f64, A64, B64, C64, Flags);
} else {
SDValue Mul = DAG.getNode(ISD::FMUL, dl, MVT::f64, A64, B64, Flags);
WideRes = DAG.getNode(ISD::FADD, dl, MVT::f64, Mul, C64, Flags);
}

// Convert back to FP16 as an integer.
return DAG.getNode(GetPromotionOpcode(MVT::f64, OVT), dl, MVT::i16,
WideRes);
}

Res = DAG.getNode(N->getOpcode(), dl, NVT, Op0, Op1, Op2, Flags);
return DAG.getNode(GetPromotionOpcode(NVT, OVT), dl, MVT::i16, Res);
}

Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,11 @@ void TargetLoweringBase::initActions() {
}
}

// If f16 fma is not natively supported, the value must be promoted to an f64
// (and not to f32!) to prevent double rounding issues.
AddPromotedToType(ISD::FMA, MVT::f16, MVT::f64);
AddPromotedToType(ISD::STRICT_FMA, MVT::f16, MVT::f64);

// Set default actions for various operations.
for (MVT VT : MVT::all_valuetypes()) {
// Default all indexed load / store to expand.
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, XLenVT, Custom);
}

if (!Subtarget.hasStdExtD()) {
// FIXME: handle f16 fma when f64 is not legal. Using an f32 fma
// instruction runs into double rounding issues, so this is wrong.
// Normally we'd use an f64 fma, but without the D extension the f64 type
// is not legal. This should probably be a libcall.
AddPromotedToType(ISD::FMA, MVT::f16, MVT::f32);
AddPromotedToType(ISD::STRICT_FMA, MVT::f16, MVT::f32);
}

setOperationAction(ISD::BITCAST, MVT::i16, Custom);

setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Legal);
Expand Down
10 changes: 5 additions & 5 deletions llvm/test/CodeGen/AArch64/f16-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1378,11 +1378,11 @@ define half @test_log2(half %a) #0 {
define half @test_fma(half %a, half %b, half %c) #0 {
; CHECK-CVT-SD-LABEL: test_fma:
; CHECK-CVT-SD: // %bb.0:
; CHECK-CVT-SD-NEXT: fcvt s2, h2
; CHECK-CVT-SD-NEXT: fcvt s1, h1
; CHECK-CVT-SD-NEXT: fcvt s0, h0
; CHECK-CVT-SD-NEXT: fmadd s0, s0, s1, s2
; CHECK-CVT-SD-NEXT: fcvt h0, s0
; CHECK-CVT-SD-NEXT: fcvt d2, h2
; CHECK-CVT-SD-NEXT: fcvt d1, h1
; CHECK-CVT-SD-NEXT: fcvt d0, h0
; CHECK-CVT-SD-NEXT: fmadd d0, d0, d1, d2
; CHECK-CVT-SD-NEXT: fcvt h0, d0
; CHECK-CVT-SD-NEXT: ret
;
; CHECK-FP16-LABEL: test_fma:
Expand Down
Loading