Skip to content

Commit e71b92d

Browse files
committed
[SelectionDAG] Add STRICT_BF16_TO_FP and STRICT_FP_TO_BF16
This patch adds the support for `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16`. Fix #78540.
1 parent 3a3302e commit e71b92d

File tree

6 files changed

+45
-19
lines changed

6 files changed

+45
-19
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

+2
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,8 @@ enum NodeType {
921921
/// has native conversions.
922922
BF16_TO_FP,
923923
FP_TO_BF16,
924+
STRICT_BF16_TO_FP,
925+
STRICT_FP_TO_BF16,
924926

925927
/// Perform various unary floating-point operations inspired by libm. For
926928
/// FPOWI, the result is undefined if the integer operand doesn't fit into

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

+2
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,8 @@ END_TWO_BYTE_PACK()
698698
return false;
699699
case ISD::STRICT_FP16_TO_FP:
700700
case ISD::STRICT_FP_TO_FP16:
701+
case ISD::STRICT_BF16_TO_FP:
702+
case ISD::STRICT_FP_TO_BF16:
701703
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
702704
case ISD::STRICT_##DAGN:
703705
#include "llvm/IR/ConstrainedOps.def"

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

+23-9
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
10341034
Node->getOperand(0).getValueType());
10351035
break;
10361036
case ISD::STRICT_FP_TO_FP16:
1037+
case ISD::STRICT_FP_TO_BF16:
10371038
case ISD::STRICT_SINT_TO_FP:
10381039
case ISD::STRICT_UINT_TO_FP:
10391040
case ISD::STRICT_LRINT:
@@ -3263,12 +3264,15 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32633264
Results.push_back(Tmp1);
32643265
break;
32653266
}
3267+
case ISD::STRICT_BF16_TO_FP:
3268+
// We don't support this expansion for now.
3269+
break;
32663270
case ISD::BF16_TO_FP: {
32673271
// Always expand bf16 to f32 casts, they lower to ext + shift.
32683272
//
32693273
// Note that the operand of this code can be bf16 or an integer type in case
32703274
// bf16 is not supported on the target and was softened.
3271-
SDValue Op = Node->getOperand(0);
3275+
SDValue Op = Node->getOperand(Node->getOpcode() == ISD::BF16_TO_FP ? 0 : 1);
32723276
if (Op.getValueType() == MVT::bf16) {
32733277
Op = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32,
32743278
DAG.getNode(ISD::BITCAST, dl, MVT::i16, Op));
@@ -3286,10 +3290,15 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32863290
Results.push_back(Op);
32873291
break;
32883292
}
3293+
case ISD::STRICT_FP_TO_BF16:
3294+
// We don't support this expansion for now.
3295+
break;
32893296
case ISD::FP_TO_BF16: {
3290-
SDValue Op = Node->getOperand(0);
3297+
bool IsStrictFP = Node->getOpcode() == ISD::STRICT_FP_TO_BF16;
3298+
SDValue Op = Node->getOperand(IsStrictFP ? 1 : 0);
32913299
if (Op.getValueType() != MVT::f32)
3292-
Op = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, Op,
3300+
Op = DAG.getNode(IsStrictFP ? ISD::STRICT_FP_ROUND : ISD::FP_ROUND, dl,
3301+
MVT::f32, Op,
32933302
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
32943303
Op = DAG.getNode(
32953304
ISD::SRL, dl, MVT::i32, DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op),
@@ -4788,12 +4797,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
47884797
break;
47894798
}
47904799
case ISD::STRICT_FP_EXTEND:
4791-
case ISD::STRICT_FP_TO_FP16: {
4792-
RTLIB::Libcall LC =
4793-
Node->getOpcode() == ISD::STRICT_FP_TO_FP16
4794-
? RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16)
4795-
: RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
4796-
Node->getValueType(0));
4800+
case ISD::STRICT_FP_TO_FP16:
4801+
case ISD::STRICT_FP_TO_BF16: {
4802+
RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
4803+
if (Node->getOpcode() == ISD::STRICT_FP_TO_FP16)
4804+
LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16);
4805+
else if (Node->getOpcode() == ISD::STRICT_FP_TO_BF16)
4806+
LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::bf16);
4807+
else
4808+
LC = RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
4809+
Node->getValueType(0));
4810+
47974811
assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to legalize as libcall");
47984812

47994813
TargetLowering::MakeLibCallOptions CallOptions;

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

+15-10
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
918918
case ISD::STRICT_FP_TO_FP16:
919919
case ISD::FP_TO_FP16: // Same as FP_ROUND for softening purposes
920920
case ISD::FP_TO_BF16:
921+
case ISD::STRICT_FP_TO_BF16:
921922
case ISD::STRICT_FP_ROUND:
922923
case ISD::FP_ROUND: Res = SoftenFloatOp_FP_ROUND(N); break;
923924
case ISD::STRICT_FP_TO_SINT:
@@ -2193,13 +2194,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
21932194
if (RetVT == MVT::f16)
21942195
return ISD::STRICT_FP_TO_FP16;
21952196

2196-
if (OpVT == MVT::bf16) {
2197-
// TODO: return ISD::STRICT_BF16_TO_FP;
2198-
}
2197+
if (OpVT == MVT::bf16)
2198+
return ISD::STRICT_BF16_TO_FP;
21992199

2200-
if (RetVT == MVT::bf16) {
2201-
// TODO: return ISD::STRICT_FP_TO_BF16;
2202-
}
2200+
if (RetVT == MVT::bf16)
2201+
return ISD::STRICT_FP_TO_BF16;
22032202

22042203
report_fatal_error("Attempt at an invalid promotion-related conversion");
22052204
}
@@ -2999,10 +2998,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
29992998
EVT SVT = N->getOperand(0).getValueType();
30002999

30013000
if (N->isStrictFPOpcode()) {
3002-
assert(RVT == MVT::f16);
3003-
SDValue Res =
3004-
DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
3005-
{N->getOperand(0), N->getOperand(1)});
3001+
// FIXME: assume we only have two f16 variants for now.
3002+
unsigned Opcode;
3003+
if (RVT == MVT::f16)
3004+
Opcode = ISD::STRICT_FP_TO_FP16;
3005+
else if (RVT == MVT::bf16)
3006+
Opcode = ISD::STRICT_FP_TO_BF16;
3007+
else
3008+
llvm_unreachable("unknown half type");
3009+
SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
3010+
{N->getOperand(0), N->getOperand(1)});
30063011
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
30073012
return Res;
30083013
}

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
165165
case ISD::FP_TO_FP16:
166166
Res = PromoteIntRes_FP_TO_FP16_BF16(N);
167167
break;
168+
case ISD::STRICT_FP_TO_BF16:
168169
case ISD::STRICT_FP_TO_FP16:
169170
Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
170171
break;

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
379379
case ISD::FP_TO_FP16: return "fp_to_fp16";
380380
case ISD::STRICT_FP_TO_FP16: return "strict_fp_to_fp16";
381381
case ISD::BF16_TO_FP: return "bf16_to_fp";
382+
case ISD::STRICT_BF16_TO_FP: return "strict_bf16_to_fp";
382383
case ISD::FP_TO_BF16: return "fp_to_bf16";
384+
case ISD::STRICT_FP_TO_BF16: return "strict_fp_to_bf16";
383385
case ISD::LROUND: return "lround";
384386
case ISD::STRICT_LROUND: return "strict_lround";
385387
case ISD::LLROUND: return "llround";

0 commit comments

Comments
 (0)