@@ -1034,6 +1034,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
1034
1034
Node->getOperand (0 ).getValueType ());
1035
1035
break ;
1036
1036
case ISD::STRICT_FP_TO_FP16:
1037
+ case ISD::STRICT_FP_TO_BF16:
1037
1038
case ISD::STRICT_SINT_TO_FP:
1038
1039
case ISD::STRICT_UINT_TO_FP:
1039
1040
case ISD::STRICT_LRINT:
@@ -3263,12 +3264,15 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
3263
3264
Results.push_back (Tmp1);
3264
3265
break ;
3265
3266
}
3267
+ case ISD::STRICT_BF16_TO_FP:
3268
+ // We don't support this expansion for now.
3269
+ break ;
3266
3270
case ISD::BF16_TO_FP: {
3267
3271
// Always expand bf16 to f32 casts, they lower to ext + shift.
3268
3272
//
3269
3273
// Note that the operand of this code can be bf16 or an integer type in case
3270
3274
// bf16 is not supported on the target and was softened.
3271
- SDValue Op = Node->getOperand (0 );
3275
+ SDValue Op = Node->getOperand (Node-> getOpcode () == ISD::BF16_TO_FP ? 0 : 1 );
3272
3276
if (Op.getValueType () == MVT::bf16) {
3273
3277
Op = DAG.getNode (ISD::ANY_EXTEND, dl, MVT::i32,
3274
3278
DAG.getNode (ISD::BITCAST, dl, MVT::i16, Op));
@@ -3286,10 +3290,15 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
3286
3290
Results.push_back (Op);
3287
3291
break ;
3288
3292
}
3293
+ case ISD::STRICT_FP_TO_BF16:
3294
+ // We don't support this expansion for now.
3295
+ break ;
3289
3296
case ISD::FP_TO_BF16: {
3290
- SDValue Op = Node->getOperand (0 );
3297
+ bool IsStrictFP = Node->getOpcode () == ISD::STRICT_FP_TO_BF16;
3298
+ SDValue Op = Node->getOperand (IsStrictFP ? 1 : 0 );
3291
3299
if (Op.getValueType () != MVT::f32)
3292
- Op = DAG.getNode (ISD::FP_ROUND, dl, MVT::f32, Op,
3300
+ Op = DAG.getNode (IsStrictFP ? ISD::STRICT_FP_ROUND : ISD::FP_ROUND, dl,
3301
+ MVT::f32, Op,
3293
3302
DAG.getIntPtrConstant (0 , dl, /* isTarget=*/ true ));
3294
3303
Op = DAG.getNode (
3295
3304
ISD::SRL, dl, MVT::i32, DAG.getNode (ISD::BITCAST, dl, MVT::i32, Op),
@@ -4788,12 +4797,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
4788
4797
break ;
4789
4798
}
4790
4799
case ISD::STRICT_FP_EXTEND:
4791
- case ISD::STRICT_FP_TO_FP16: {
4792
- RTLIB::Libcall LC =
4793
- Node->getOpcode () == ISD::STRICT_FP_TO_FP16
4794
- ? RTLIB::getFPROUND (Node->getOperand (1 ).getValueType (), MVT::f16)
4795
- : RTLIB::getFPEXT (Node->getOperand (1 ).getValueType (),
4796
- Node->getValueType (0 ));
4800
+ case ISD::STRICT_FP_TO_FP16:
4801
+ case ISD::STRICT_FP_TO_BF16: {
4802
+ RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
4803
+ if (Node->getOpcode () == ISD::STRICT_FP_TO_FP16)
4804
+ LC = RTLIB::getFPROUND (Node->getOperand (1 ).getValueType (), MVT::f16);
4805
+ else if (Node->getOpcode () == ISD::STRICT_FP_TO_BF16)
4806
+ LC = RTLIB::getFPROUND (Node->getOperand (1 ).getValueType (), MVT::bf16);
4807
+ else
4808
+ LC = RTLIB::getFPEXT (Node->getOperand (1 ).getValueType (),
4809
+ Node->getValueType (0 ));
4810
+
4797
4811
assert (LC != RTLIB::UNKNOWN_LIBCALL && " Unable to legalize as libcall" );
4798
4812
4799
4813
TargetLowering::MakeLibCallOptions CallOptions;
0 commit comments