Skip to content

Commit 03de8c2

Browse files
committed
[SelectionDAG] Add STRICT_BF16_TO_FP and STRICT_FP_TO_BF16
This patch adds support for `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16`. Fixes #78540.
1 parent 3e0425c commit 03de8c2

File tree

9 files changed

+164
-21
lines changed

9 files changed

+164
-21
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

+2
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,8 @@ enum NodeType {
921921
/// has native conversions.
922922
BF16_TO_FP,
923923
FP_TO_BF16,
924+
STRICT_BF16_TO_FP,
925+
STRICT_FP_TO_BF16,
924926

925927
/// Perform various unary floating-point operations inspired by libm. For
926928
/// FPOWI, the result is undefined if the integer operand doesn't fit into

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

+2
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,8 @@ END_TWO_BYTE_PACK()
698698
return false;
699699
case ISD::STRICT_FP16_TO_FP:
700700
case ISD::STRICT_FP_TO_FP16:
701+
case ISD::STRICT_BF16_TO_FP:
702+
case ISD::STRICT_FP_TO_BF16:
701703
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
702704
case ISD::STRICT_##DAGN:
703705
#include "llvm/IR/ConstrainedOps.def"

llvm/include/llvm/Target/TargetSelectionDAG.td

+13
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,8 @@ def fp_to_sint_sat : SDNode<"ISD::FP_TO_SINT_SAT" , SDTFPToIntSatOp>;
541541
def fp_to_uint_sat : SDNode<"ISD::FP_TO_UINT_SAT" , SDTFPToIntSatOp>;
542542
def f16_to_fp : SDNode<"ISD::FP16_TO_FP" , SDTIntToFPOp>;
543543
def fp_to_f16 : SDNode<"ISD::FP_TO_FP16" , SDTFPToIntOp>;
544+
def bf16_to_fp : SDNode<"ISD::BF16_TO_FP" , SDTIntToFPOp>;
545+
def fp_to_bf16 : SDNode<"ISD::FP_TO_BF16" , SDTFPToIntOp>;
544546

545547
def strict_fadd : SDNode<"ISD::STRICT_FADD",
546548
SDTFPBinOp, [SDNPHasChain, SDNPCommutative]>;
@@ -620,6 +622,11 @@ def strict_f16_to_fp : SDNode<"ISD::STRICT_FP16_TO_FP",
620622
def strict_fp_to_f16 : SDNode<"ISD::STRICT_FP_TO_FP16",
621623
SDTFPToIntOp, [SDNPHasChain]>;
622624

625+
def strict_bf16_to_fp : SDNode<"ISD::STRICT_BF16_TO_FP",
626+
SDTIntToFPOp, [SDNPHasChain]>;
627+
def strict_fp_to_bf16 : SDNode<"ISD::STRICT_FP_TO_BF16",
628+
SDTFPToIntOp, [SDNPHasChain]>;
629+
623630
def strict_fsetcc : SDNode<"ISD::STRICT_FSETCC", SDTSetCC, [SDNPHasChain]>;
624631
def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
625632

@@ -1591,6 +1598,12 @@ def any_f16_to_fp : PatFrags<(ops node:$src),
15911598
def any_fp_to_f16 : PatFrags<(ops node:$src),
15921599
[(fp_to_f16 node:$src),
15931600
(strict_fp_to_f16 node:$src)]>;
1601+
def any_bf16_to_fp : PatFrags<(ops node:$src),
1602+
[(bf16_to_fp node:$src),
1603+
(strict_bf16_to_fp node:$src)]>;
1604+
def any_fp_to_bf16 : PatFrags<(ops node:$src),
1605+
[(fp_to_bf16 node:$src),
1606+
(strict_fp_to_bf16 node:$src)]>;
15941607

15951608
multiclass binary_atomic_op_ord {
15961609
def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

+18-6
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
10471047
Node->getOperand(0).getValueType());
10481048
break;
10491049
case ISD::STRICT_FP_TO_FP16:
1050+
case ISD::STRICT_FP_TO_BF16:
10501051
case ISD::STRICT_SINT_TO_FP:
10511052
case ISD::STRICT_UINT_TO_FP:
10521053
case ISD::STRICT_LRINT:
@@ -3263,6 +3264,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32633264
Results.push_back(Tmp1);
32643265
break;
32653266
}
3267+
case ISD::STRICT_BF16_TO_FP:
3268+
// We don't support this expansion for now.
3269+
break;
32663270
case ISD::BF16_TO_FP: {
32673271
// Always expand bf16 to f32 casts, they lower to ext + shift.
32683272
//
@@ -3286,6 +3290,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32863290
Results.push_back(Op);
32873291
break;
32883292
}
3293+
case ISD::STRICT_FP_TO_BF16:
3294+
// We don't support this expansion for now.
3295+
break;
32893296
case ISD::FP_TO_BF16: {
32903297
SDValue Op = Node->getOperand(0);
32913298
if (Op.getValueType() != MVT::f32)
@@ -4792,12 +4799,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
47924799
break;
47934800
}
47944801
case ISD::STRICT_FP_EXTEND:
4795-
case ISD::STRICT_FP_TO_FP16: {
4796-
RTLIB::Libcall LC =
4797-
Node->getOpcode() == ISD::STRICT_FP_TO_FP16
4798-
? RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16)
4799-
: RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
4800-
Node->getValueType(0));
4802+
case ISD::STRICT_FP_TO_FP16:
4803+
case ISD::STRICT_FP_TO_BF16: {
4804+
RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
4805+
if (Node->getOpcode() == ISD::STRICT_FP_TO_FP16)
4806+
LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16);
4807+
else if (Node->getOpcode() == ISD::STRICT_FP_TO_BF16)
4808+
LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::bf16);
4809+
else
4810+
LC = RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
4811+
Node->getValueType(0));
4812+
48014813
assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to legalize as libcall");
48024814

48034815
TargetLowering::MakeLibCallOptions CallOptions;

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

+27-14
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
918918
case ISD::STRICT_FP_TO_FP16:
919919
case ISD::FP_TO_FP16: // Same as FP_ROUND for softening purposes
920920
case ISD::FP_TO_BF16:
921+
case ISD::STRICT_FP_TO_BF16:
921922
case ISD::STRICT_FP_ROUND:
922923
case ISD::FP_ROUND: Res = SoftenFloatOp_FP_ROUND(N); break;
923924
case ISD::STRICT_FP_TO_SINT:
@@ -970,6 +971,7 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
970971
assert(N->getOpcode() == ISD::FP_ROUND || N->getOpcode() == ISD::FP_TO_FP16 ||
971972
N->getOpcode() == ISD::STRICT_FP_TO_FP16 ||
972973
N->getOpcode() == ISD::FP_TO_BF16 ||
974+
N->getOpcode() == ISD::STRICT_FP_TO_BF16 ||
973975
N->getOpcode() == ISD::STRICT_FP_ROUND);
974976

975977
bool IsStrict = N->isStrictFPOpcode();
@@ -980,7 +982,8 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
980982
if (N->getOpcode() == ISD::FP_TO_FP16 ||
981983
N->getOpcode() == ISD::STRICT_FP_TO_FP16)
982984
FloatRVT = MVT::f16;
983-
else if (N->getOpcode() == ISD::FP_TO_BF16)
985+
else if (N->getOpcode() == ISD::FP_TO_BF16 ||
986+
N->getOpcode() == ISD::STRICT_FP_TO_BF16)
984987
FloatRVT = MVT::bf16;
985988

986989
RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, FloatRVT);
@@ -2193,13 +2196,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
21932196
if (RetVT == MVT::f16)
21942197
return ISD::STRICT_FP_TO_FP16;
21952198

2196-
if (OpVT == MVT::bf16) {
2197-
// TODO: return ISD::STRICT_BF16_TO_FP;
2198-
}
2199+
if (OpVT == MVT::bf16)
2200+
return ISD::STRICT_BF16_TO_FP;
21992201

2200-
if (RetVT == MVT::bf16) {
2201-
// TODO: return ISD::STRICT_FP_TO_BF16;
2202-
}
2202+
if (RetVT == MVT::bf16)
2203+
return ISD::STRICT_FP_TO_BF16;
22032204

22042205
report_fatal_error("Attempt at an invalid promotion-related conversion");
22052206
}
@@ -2999,10 +3000,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
29993000
EVT SVT = N->getOperand(0).getValueType();
30003001

30013002
if (N->isStrictFPOpcode()) {
3002-
assert(RVT == MVT::f16);
3003-
SDValue Res =
3004-
DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
3005-
{N->getOperand(0), N->getOperand(1)});
3003+
// FIXME: This assumes f16 and bf16 are the only two half-precision types for now.
3004+
unsigned Opcode;
3005+
if (RVT == MVT::f16)
3006+
Opcode = ISD::STRICT_FP_TO_FP16;
3007+
else if (RVT == MVT::bf16)
3008+
Opcode = ISD::STRICT_FP_TO_BF16;
3009+
else
3010+
llvm_unreachable("unknown half type");
3011+
SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
3012+
{N->getOperand(0), N->getOperand(1)});
30063013
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
30073014
return Res;
30083015
}
@@ -3192,10 +3199,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
31923199
Op = GetSoftPromotedHalf(N->getOperand(IsStrict ? 1 : 0));
31933200

31943201
if (IsStrict) {
3195-
assert(SVT == MVT::f16);
3202+
unsigned Opcode;
3203+
if (SVT == MVT::f16)
3204+
Opcode = ISD::STRICT_FP16_TO_FP;
3205+
else if (SVT == MVT::bf16)
3206+
Opcode = ISD::STRICT_BF16_TO_FP;
3207+
else
3208+
llvm_unreachable("unknown half type");
31963209
SDValue Res =
3197-
DAG.getNode(ISD::STRICT_FP16_TO_FP, SDLoc(N),
3198-
{N->getValueType(0), MVT::Other}, {N->getOperand(0), Op});
3210+
DAG.getNode(Opcode, SDLoc(N), {N->getValueType(0), MVT::Other},
3211+
{N->getOperand(0), Op});
31993212
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
32003213
ReplaceValueWith(SDValue(N, 0), Res);
32013214
return SDValue();

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
165165
case ISD::FP_TO_FP16:
166166
Res = PromoteIntRes_FP_TO_FP16_BF16(N);
167167
break;
168+
case ISD::STRICT_FP_TO_BF16:
168169
case ISD::STRICT_FP_TO_FP16:
169170
Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
170171
break;

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
380380
case ISD::FP_TO_FP16: return "fp_to_fp16";
381381
case ISD::STRICT_FP_TO_FP16: return "strict_fp_to_fp16";
382382
case ISD::BF16_TO_FP: return "bf16_to_fp";
383+
case ISD::STRICT_BF16_TO_FP: return "strict_bf16_to_fp";
383384
case ISD::FP_TO_BF16: return "fp_to_bf16";
385+
case ISD::STRICT_FP_TO_BF16: return "strict_fp_to_bf16";
384386
case ISD::LROUND: return "lround";
385387
case ISD::STRICT_LROUND: return "strict_lround";
386388
case ISD::LLROUND: return "llround";

llvm/lib/Target/X86/X86ISelLowering.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
393393
}
394394

395395
for (auto Op : {ISD::FP16_TO_FP, ISD::STRICT_FP16_TO_FP, ISD::FP_TO_FP16,
396-
ISD::STRICT_FP_TO_FP16}) {
396+
ISD::STRICT_FP_TO_FP16, ISD::STRICT_FP_TO_BF16}) {
397397
// Special handling for half-precision floating point conversions.
398398
// If we don't have F16C support, then lower half float conversions
399399
// into library calls.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc < %s -mtriple=i686-linux-gnu | FileCheck %s --check-prefix=X32
3+
; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s --check-prefix=X64
4+
5+
@a = global bfloat 0xR0000, align 2
6+
@b = global bfloat 0xR0000, align 2
7+
@c = global bfloat 0xR0000, align 2
8+
9+
; FIXME: We don't have strict extend yet.
10+
; define float @bfloat_to_float() strictfp {
11+
; %1 = load bfloat, ptr @a, align 2
12+
; %2 = tail call float @llvm.experimental.constrained.fpext.f32.bfloat(bfloat %1, metadata !"fpexcept.strict") #0
13+
; ret float %2
14+
; }
15+
16+
; define double @bfloat_to_double() strictfp {
17+
; %1 = load bfloat, ptr @a, align 2
18+
; %2 = tail call double @llvm.experimental.constrained.fpext.f64.bfloat(bfloat %1, metadata !"fpexcept.strict") #0
19+
; ret double %2
20+
; }
21+
22+
define void @float_to_bfloat(float %0) strictfp {
23+
; X32-LABEL: float_to_bfloat:
24+
; X32: # %bb.0:
25+
; X32-NEXT: subl $12, %esp
26+
; X32-NEXT: .cfi_def_cfa_offset 16
27+
; X32-NEXT: flds {{[0-9]+}}(%esp)
28+
; X32-NEXT: fstps (%esp)
29+
; X32-NEXT: wait
30+
; X32-NEXT: calll __truncsfbf2
31+
; X32-NEXT: movw %ax, a
32+
; X32-NEXT: addl $12, %esp
33+
; X32-NEXT: .cfi_def_cfa_offset 4
34+
; X32-NEXT: retl
35+
;
36+
; X64-LABEL: float_to_bfloat:
37+
; X64: # %bb.0:
38+
; X64-NEXT: pushq %rax
39+
; X64-NEXT: .cfi_def_cfa_offset 16
40+
; X64-NEXT: callq __truncsfbf2@PLT
41+
; X64-NEXT: movq a@GOTPCREL(%rip), %rcx
42+
; X64-NEXT: movw %ax, (%rcx)
43+
; X64-NEXT: popq %rax
44+
; X64-NEXT: .cfi_def_cfa_offset 8
45+
; X64-NEXT: retq
46+
%2 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float %0, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
47+
store bfloat %2, ptr @a, align 2
48+
ret void
49+
}
50+
51+
define void @double_to_bfloat(double %0) strictfp {
52+
; X32-LABEL: double_to_bfloat:
53+
; X32: # %bb.0:
54+
; X32-NEXT: subl $12, %esp
55+
; X32-NEXT: .cfi_def_cfa_offset 16
56+
; X32-NEXT: fldl {{[0-9]+}}(%esp)
57+
; X32-NEXT: fstpl (%esp)
58+
; X32-NEXT: wait
59+
; X32-NEXT: calll __truncdfbf2
60+
; X32-NEXT: movw %ax, a
61+
; X32-NEXT: addl $12, %esp
62+
; X32-NEXT: .cfi_def_cfa_offset 4
63+
; X32-NEXT: retl
64+
;
65+
; X64-LABEL: double_to_bfloat:
66+
; X64: # %bb.0:
67+
; X64-NEXT: pushq %rax
68+
; X64-NEXT: .cfi_def_cfa_offset 16
69+
; X64-NEXT: callq __truncdfbf2@PLT
70+
; X64-NEXT: movq a@GOTPCREL(%rip), %rcx
71+
; X64-NEXT: movw %ax, (%rcx)
72+
; X64-NEXT: popq %rax
73+
; X64-NEXT: .cfi_def_cfa_offset 8
74+
; X64-NEXT: retq
75+
%2 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f64(double %0, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
76+
store bfloat %2, ptr @a, align 2
77+
ret void
78+
}
79+
80+
; define void @add() strictfp {
81+
; %1 = load bfloat, ptr @a, align 2
82+
; %2 = tail call float @llvm.experimental.constrained.fpext.f32.bfloat(bfloat %1, metadata !"fpexcept.strict") #0
83+
; %3 = load bfloat, ptr @b, align 2
84+
; %4 = tail call float @llvm.experimental.constrained.fpext.f32.bfloat(bfloat %3, metadata !"fpexcept.strict") #0
85+
; %5 = tail call float @llvm.experimental.constrained.fadd.f32(float %2, float %4, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
86+
; %6 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float %5, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
87+
; store bfloat %6, ptr @c, align 2
88+
; ret void
89+
; }
90+
91+
; declare float @llvm.experimental.constrained.fpext.f32.bfloat(bfloat, metadata)
92+
; declare double @llvm.experimental.constrained.fpext.f64.bfloat(bfloat, metadata)
93+
; declare float @llvm.experimental.constrained.fadd.f32(float, float, metadata, metadata)
94+
declare bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float, metadata, metadata)
95+
declare bfloat @llvm.experimental.constrained.fptrunc.bfloat.f64(double, metadata, metadata)
96+
97+
attributes #0 = { strictfp }
98+

0 commit comments

Comments
 (0)