@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
918918 case ISD::STRICT_FP_TO_FP16:
919919 case ISD::FP_TO_FP16: // Same as FP_ROUND for softening purposes
920920 case ISD::FP_TO_BF16:
921+ case ISD::STRICT_FP_TO_BF16:
921922 case ISD::STRICT_FP_ROUND:
922923 case ISD::FP_ROUND: Res = SoftenFloatOp_FP_ROUND (N); break ;
923924 case ISD::STRICT_FP_TO_SINT:
@@ -970,6 +971,7 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
970971 assert (N->getOpcode () == ISD::FP_ROUND || N->getOpcode () == ISD::FP_TO_FP16 ||
971972 N->getOpcode () == ISD::STRICT_FP_TO_FP16 ||
972973 N->getOpcode () == ISD::FP_TO_BF16 ||
974+ N->getOpcode () == ISD::STRICT_FP_TO_BF16 ||
973975 N->getOpcode () == ISD::STRICT_FP_ROUND);
974976
975977 bool IsStrict = N->isStrictFPOpcode ();
@@ -980,7 +982,8 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
980982 if (N->getOpcode () == ISD::FP_TO_FP16 ||
981983 N->getOpcode () == ISD::STRICT_FP_TO_FP16)
982984 FloatRVT = MVT::f16 ;
983- else if (N->getOpcode () == ISD::FP_TO_BF16)
985+ else if (N->getOpcode () == ISD::FP_TO_BF16 ||
986+ N->getOpcode () == ISD::STRICT_FP_TO_BF16)
984987 FloatRVT = MVT::bf16 ;
985988
986989 RTLIB::Libcall LC = RTLIB::getFPROUND (SVT, FloatRVT);
@@ -2193,13 +2196,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
21932196 if (RetVT == MVT::f16 )
21942197 return ISD::STRICT_FP_TO_FP16;
21952198
2196- if (OpVT == MVT::bf16 ) {
2197- // TODO: return ISD::STRICT_BF16_TO_FP;
2198- }
2199+ if (OpVT == MVT::bf16 )
2200+ return ISD::STRICT_BF16_TO_FP;
21992201
2200- if (RetVT == MVT::bf16 ) {
2201- // TODO: return ISD::STRICT_FP_TO_BF16;
2202- }
2202+ if (RetVT == MVT::bf16 )
2203+ return ISD::STRICT_FP_TO_BF16;
22032204
22042205 report_fatal_error (" Attempt at an invalid promotion-related conversion" );
22052206}
@@ -2999,10 +3000,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
29993000 EVT SVT = N->getOperand (0 ).getValueType ();
30003001
30013002 if (N->isStrictFPOpcode ()) {
3002- assert (RVT == MVT::f16 );
3003- SDValue Res =
3004- DAG.getNode (ISD::STRICT_FP_TO_FP16, SDLoc (N), {MVT::i16 , MVT::Other},
3005- {N->getOperand (0 ), N->getOperand (1 )});
3003+ // FIXME: assume we only have two f16 variants for now.
3004+ unsigned Opcode;
3005+ if (RVT == MVT::f16 )
3006+ Opcode = ISD::STRICT_FP_TO_FP16;
3007+ else if (RVT == MVT::bf16 )
3008+ Opcode = ISD::STRICT_FP_TO_BF16;
3009+ else
3010+ llvm_unreachable (" unknown half type" );
3011+ SDValue Res = DAG.getNode (Opcode, SDLoc (N), {MVT::i16 , MVT::Other},
3012+ {N->getOperand (0 ), N->getOperand (1 )});
30063013 ReplaceValueWith (SDValue (N, 1 ), Res.getValue (1 ));
30073014 return Res;
30083015 }
@@ -3192,10 +3199,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
31923199 Op = GetSoftPromotedHalf (N->getOperand (IsStrict ? 1 : 0 ));
31933200
31943201 if (IsStrict) {
3195- assert (SVT == MVT::f16 );
3202+ unsigned Opcode;
3203+ if (SVT == MVT::f16 )
3204+ Opcode = ISD::STRICT_FP16_TO_FP;
3205+ else if (SVT == MVT::bf16 )
3206+ Opcode = ISD::STRICT_BF16_TO_FP;
3207+ else
3208+ llvm_unreachable (" unknown half type" );
31963209 SDValue Res =
3197- DAG.getNode (ISD::STRICT_FP16_TO_FP , SDLoc (N),
3198- {N->getValueType ( 0 ), MVT::Other}, {N-> getOperand (0 ), Op});
3210+ DAG.getNode (Opcode , SDLoc (N), {N-> getValueType ( 0 ), MVT::Other} ,
3211+ {N->getOperand (0 ), Op});
31993212 ReplaceValueWith (SDValue (N, 1 ), Res.getValue (1 ));
32003213 ReplaceValueWith (SDValue (N, 0 ), Res);
32013214 return SDValue ();
0 commit comments