@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
918
918
case ISD::STRICT_FP_TO_FP16:
919
919
case ISD::FP_TO_FP16: // Same as FP_ROUND for softening purposes
920
920
case ISD::FP_TO_BF16:
921
+ case ISD::STRICT_FP_TO_BF16:
921
922
case ISD::STRICT_FP_ROUND:
922
923
case ISD::FP_ROUND: Res = SoftenFloatOp_FP_ROUND (N); break ;
923
924
case ISD::STRICT_FP_TO_SINT:
@@ -970,6 +971,7 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
970
971
assert (N->getOpcode () == ISD::FP_ROUND || N->getOpcode () == ISD::FP_TO_FP16 ||
971
972
N->getOpcode () == ISD::STRICT_FP_TO_FP16 ||
972
973
N->getOpcode () == ISD::FP_TO_BF16 ||
974
+ N->getOpcode () == ISD::STRICT_FP_TO_BF16 ||
973
975
N->getOpcode () == ISD::STRICT_FP_ROUND);
974
976
975
977
bool IsStrict = N->isStrictFPOpcode ();
@@ -980,7 +982,8 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
980
982
if (N->getOpcode () == ISD::FP_TO_FP16 ||
981
983
N->getOpcode () == ISD::STRICT_FP_TO_FP16)
982
984
FloatRVT = MVT::f16;
983
- else if (N->getOpcode () == ISD::FP_TO_BF16)
985
+ else if (N->getOpcode () == ISD::FP_TO_BF16 ||
986
+ N->getOpcode () == ISD::STRICT_FP_TO_BF16)
984
987
FloatRVT = MVT::bf16;
985
988
986
989
RTLIB::Libcall LC = RTLIB::getFPROUND (SVT, FloatRVT);
@@ -2193,13 +2196,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
2193
2196
if (RetVT == MVT::f16)
2194
2197
return ISD::STRICT_FP_TO_FP16;
2195
2198
2196
- if (OpVT == MVT::bf16) {
2197
- // TODO: return ISD::STRICT_BF16_TO_FP;
2198
- }
2199
+ if (OpVT == MVT::bf16)
2200
+ return ISD::STRICT_BF16_TO_FP;
2199
2201
2200
- if (RetVT == MVT::bf16) {
2201
- // TODO: return ISD::STRICT_FP_TO_BF16;
2202
- }
2202
+ if (RetVT == MVT::bf16)
2203
+ return ISD::STRICT_FP_TO_BF16;
2203
2204
2204
2205
report_fatal_error (" Attempt at an invalid promotion-related conversion" );
2205
2206
}
@@ -2999,10 +3000,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
2999
3000
EVT SVT = N->getOperand (0 ).getValueType ();
3000
3001
3001
3002
if (N->isStrictFPOpcode ()) {
3002
- assert (RVT == MVT::f16);
3003
- SDValue Res =
3004
- DAG.getNode (ISD::STRICT_FP_TO_FP16, SDLoc (N), {MVT::i16, MVT::Other},
3005
- {N->getOperand (0 ), N->getOperand (1 )});
3003
+ // FIXME: assume we only have two f16 variants for now.
3004
+ unsigned Opcode;
3005
+ if (RVT == MVT::f16)
3006
+ Opcode = ISD::STRICT_FP_TO_FP16;
3007
+ else if (RVT == MVT::bf16)
3008
+ Opcode = ISD::STRICT_FP_TO_BF16;
3009
+ else
3010
+ llvm_unreachable (" unknown half type" );
3011
+ SDValue Res = DAG.getNode (Opcode, SDLoc (N), {MVT::i16, MVT::Other},
3012
+ {N->getOperand (0 ), N->getOperand (1 )});
3006
3013
ReplaceValueWith (SDValue (N, 1 ), Res.getValue (1 ));
3007
3014
return Res;
3008
3015
}
@@ -3192,10 +3199,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
3192
3199
Op = GetSoftPromotedHalf (N->getOperand (IsStrict ? 1 : 0 ));
3193
3200
3194
3201
if (IsStrict) {
3195
- assert (SVT == MVT::f16);
3202
+ unsigned Opcode;
3203
+ if (SVT == MVT::f16)
3204
+ Opcode = ISD::STRICT_FP16_TO_FP;
3205
+ else if (SVT == MVT::bf16)
3206
+ Opcode = ISD::STRICT_BF16_TO_FP;
3207
+ else
3208
+ llvm_unreachable (" unknown half type" );
3196
3209
SDValue Res =
3197
- DAG.getNode (ISD::STRICT_FP16_TO_FP , SDLoc (N),
3198
- {N->getValueType ( 0 ), MVT::Other}, {N-> getOperand (0 ), Op});
3210
+ DAG.getNode (Opcode , SDLoc (N), {N-> getValueType ( 0 ), MVT::Other} ,
3211
+ {N->getOperand (0 ), Op});
3199
3212
ReplaceValueWith (SDValue (N, 1 ), Res.getValue (1 ));
3200
3213
ReplaceValueWith (SDValue (N, 0 ), Res);
3201
3214
return SDValue ();
0 commit comments