[SelectionDAG] Add STRICT_BF16_TO_FP and STRICT_FP_TO_BF16 #80056

Merged
merged 1 commit into from
Mar 4, 2024

@shiltian shiltian commented Jan 30, 2024

This patch adds support for STRICT_BF16_TO_FP and STRICT_FP_TO_BF16.

@shiltian shiltian requested a review from arsenm January 30, 2024 20:35
@llvmbot llvmbot added the backend:AMDGPU and llvm:SelectionDAG labels Jan 30, 2024
llvmbot commented Jan 30, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-amdgpu

Author: Shilei Tian (shiltian)

Changes

This patch adds support for STRICT_BF16_TO_FP and STRICT_FP_TO_BF16.

Fix #78540.


Full diff: https://github.com/llvm/llvm-project/pull/80056.diff

8 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+2)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+26-8)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+15-10)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+2)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+4-2)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll (+46-5)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 349d1286c8dc4..29fa3bd842c14 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -921,6 +921,8 @@ enum NodeType {
   /// has native conversions.
   BF16_TO_FP,
   FP_TO_BF16,
+  STRICT_BF16_TO_FP,
+  STRICT_FP_TO_BF16,
 
   /// Perform various unary floating-point operations inspired by libm. For
   /// FPOWI, the result is undefined if the integer operand doesn't fit into
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 3130f6c4dce59..d1015630b05d1 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -698,6 +698,8 @@ END_TWO_BYTE_PACK()
         return false;
       case ISD::STRICT_FP16_TO_FP:
       case ISD::STRICT_FP_TO_FP16:
+      case ISD::STRICT_BF16_TO_FP:
+      case ISD::STRICT_FP_TO_BF16:
 #define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)               \
       case ISD::STRICT_##DAGN:
 #include "llvm/IR/ConstrainedOps.def"
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index d29e44f95798c..beac23a070163 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1033,6 +1033,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
                                     Node->getOperand(0).getValueType());
     break;
   case ISD::STRICT_FP_TO_FP16:
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_SINT_TO_FP:
   case ISD::STRICT_UINT_TO_FP:
   case ISD::STRICT_LRINT:
@@ -3248,12 +3249,18 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       Results.push_back(Tmp1);
     break;
   }
+  case ISD::STRICT_BF16_TO_FP:
+    // When strict mode is enforced we can't do expansion because it
+    // does not honor the "strict" properties.
+    if (TLI.isStrictFPEnabled())
+      break;
+    LLVM_FALLTHROUGH;
   case ISD::BF16_TO_FP: {
     // Always expand bf16 to f32 casts, they lower to ext + shift.
     //
     // Note that the operand of this code can be bf16 or an integer type in case
     // bf16 is not supported on the target and was softened.
-    SDValue Op = Node->getOperand(0);
+    SDValue Op = Node->getOperand(Node->getOpcode() == ISD::BF16_TO_FP ? 0 : 1);
     if (Op.getValueType() == MVT::bf16) {
       Op = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32,
                        DAG.getNode(ISD::BITCAST, dl, MVT::i16, Op));
@@ -3271,8 +3278,14 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     Results.push_back(Op);
     break;
   }
+  case ISD::STRICT_FP_TO_BF16:
+    // When strict mode is enforced we can't do expansion because it
+    // does not honor the "strict" properties.
+    if (TLI.isStrictFPEnabled())
+      break;
+    LLVM_FALLTHROUGH;
   case ISD::FP_TO_BF16: {
-    SDValue Op = Node->getOperand(0);
+    SDValue Op = Node->getOperand(Node->getOpcode() == ISD::FP_TO_BF16 ? 0 : 1);
     if (Op.getValueType() != MVT::f32)
       Op = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, Op,
                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
@@ -4773,12 +4786,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
     break;
   }
   case ISD::STRICT_FP_EXTEND:
-  case ISD::STRICT_FP_TO_FP16: {
-    RTLIB::Libcall LC =
-        Node->getOpcode() == ISD::STRICT_FP_TO_FP16
-            ? RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16)
-            : RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
-                              Node->getValueType(0));
+  case ISD::STRICT_FP_TO_FP16:
+  case ISD::STRICT_FP_TO_BF16: {
+    RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+    if (Node->getOpcode() == ISD::STRICT_FP_TO_FP16)
+      LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16);
+    else if (Node->getOpcode() == ISD::STRICT_FP_TO_BF16)
+      LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::bf16);
+    else
+      LC = RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
+                           Node->getValueType(0));
+
     assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to legalize as libcall");
 
     TargetLowering::MakeLibCallOptions CallOptions;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index f0a04589fbfdc..ea0696be8edc4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
   case ISD::STRICT_FP_TO_FP16:
   case ISD::FP_TO_FP16:  // Same as FP_ROUND for softening purposes
   case ISD::FP_TO_BF16:
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_FP_ROUND:
   case ISD::FP_ROUND:    Res = SoftenFloatOp_FP_ROUND(N); break;
   case ISD::STRICT_FP_TO_SINT:
@@ -2193,13 +2194,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
   if (RetVT == MVT::f16)
     return ISD::STRICT_FP_TO_FP16;
 
-  if (OpVT == MVT::bf16) {
-    // TODO: return ISD::STRICT_BF16_TO_FP;
-  }
+  if (OpVT == MVT::bf16)
+    return ISD::STRICT_BF16_TO_FP;
 
-  if (RetVT == MVT::bf16) {
-    // TODO: return ISD::STRICT_FP_TO_BF16;
-  }
+  if (RetVT == MVT::bf16)
+    return ISD::STRICT_FP_TO_BF16;
 
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
@@ -2999,10 +2998,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
   EVT SVT = N->getOperand(0).getValueType();
 
   if (N->isStrictFPOpcode()) {
-    assert(RVT == MVT::f16);
-    SDValue Res =
-        DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
-                    {N->getOperand(0), N->getOperand(1)});
+    // FIXME: assume we only have two f16 variants for now.
+    unsigned Opcode;
+    if (RVT == MVT::f16)
+      Opcode = ISD::STRICT_FP_TO_FP16;
+    else if (RVT == MVT::bf16)
+      Opcode = ISD::STRICT_FP_TO_BF16;
+    else
+      llvm_unreachable("unknown half type");
+    SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
+                              {N->getOperand(0), N->getOperand(1)});
     ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
     return Res;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 814f746f5a4d9..62a21ad71b622 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,6 +165,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
+  case ISD::STRICT_FP_TO_BF16:
   case ISD::STRICT_FP_TO_FP16:
     Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
     break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index a28d834f0522f..c0981d8362a3b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -379,7 +379,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::FP_TO_FP16:                 return "fp_to_fp16";
   case ISD::STRICT_FP_TO_FP16:          return "strict_fp_to_fp16";
   case ISD::BF16_TO_FP:                 return "bf16_to_fp";
+  case ISD::STRICT_BF16_TO_FP:          return "strict_bf16_to_fp";
   case ISD::FP_TO_BF16:                 return "fp_to_bf16";
+  case ISD::STRICT_FP_TO_BF16:          return "strict_fp_to_bf16";
   case ISD::LROUND:                     return "lround";
   case ISD::STRICT_LROUND:              return "strict_lround";
   case ISD::LLROUND:                    return "llround";
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 7ab062bcc4da7..d5ec49fb3114f 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -539,8 +539,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction({ISD::FSIN, ISD::FCOS, ISD::FDIV}, MVT::f32, Custom);
   setOperationAction(ISD::FDIV, MVT::f64, Custom);
 
-  setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
-  setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
+  setOperationAction({ISD::BF16_TO_FP, ISD::STRICT_BF16_TO_FP},
+                     {MVT::i16, MVT::f32, MVT::f64}, Expand);
+  setOperationAction({ISD::FP_TO_BF16, ISD::STRICT_FP_TO_BF16},
+                     {MVT::i16, MVT::f32, MVT::f64}, Expand);
 
   // Custom lower these because we can't specify a rule based on an illegal
   // source bf16.
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll b/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll
index 04bf2120b78cf..549b4d0fbd01f 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll
@@ -1094,11 +1094,52 @@ define <4 x i1> @isnan_v4bf16(<4 x bfloat> %x) nounwind {
   ret <4 x i1> %1
 }
 
-; FIXME: Broken for gfx6/7
-; define i1 @isnan_bf16_strictfp(bfloat %x) strictfp nounwind {
-;   %1 = call i1 @llvm.is.fpclass.bf16(bfloat %x, i32 3) strictfp ; nan
-;   ret i1 %1
-; }
+define i1 @isnan_bf16_strictfp(bfloat %x) strictfp nounwind {
+ ; GFX7CHECK-LABEL: isnan_bf16_strictfp:
+ ; GFX7CHECK:       ; %bb.0:
+ ; GFX7CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+ ; GFX7CHECK-NEXT:    v_bfe_u32 v0, v0, 16, 15
+ ; GFX7CHECK-NEXT:    s_movk_i32 s4, 0x7f80
+ ; GFX7CHECK-NEXT:    v_cmp_lt_i32_e32 vcc, s4, v0
+ ; GFX7CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc
+ ; GFX7CHECK-NEXT:    s_setpc_b64 s[30:31]
+ ;
+ ; GFX8CHECK-LABEL: isnan_bf16_strictfp:
+ ; GFX8CHECK:       ; %bb.0:
+ ; GFX8CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+ ; GFX8CHECK-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
+ ; GFX8CHECK-NEXT:    s_movk_i32 s4, 0x7f80
+ ; GFX8CHECK-NEXT:    v_cmp_lt_i16_e32 vcc, s4, v0
+ ; GFX8CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc
+ ; GFX8CHECK-NEXT:    s_setpc_b64 s[30:31]
+ ;
+ ; GFX9CHECK-LABEL: isnan_bf16_strictfp:
+ ; GFX9CHECK:       ; %bb.0:
+ ; GFX9CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+ ; GFX9CHECK-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
+ ; GFX9CHECK-NEXT:    s_movk_i32 s4, 0x7f80
+ ; GFX9CHECK-NEXT:    v_cmp_lt_i16_e32 vcc, s4, v0
+ ; GFX9CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc
+ ; GFX9CHECK-NEXT:    s_setpc_b64 s[30:31]
+ ;
+ ; GFX10CHECK-LABEL: isnan_bf16_strictfp:
+ ; GFX10CHECK:       ; %bb.0:
+ ; GFX10CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+ ; GFX10CHECK-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
+ ; GFX10CHECK-NEXT:    v_cmp_lt_i16_e32 vcc_lo, 0x7f80, v0
+ ; GFX10CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc_lo
+ ; GFX10CHECK-NEXT:    s_setpc_b64 s[30:31]
+ ;
+ ; GFX11CHECK-LABEL: isnan_bf16_strictfp:
+ ; GFX11CHECK:       ; %bb.0:
+ ; GFX11CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+ ; GFX11CHECK-NEXT:    v_and_b32_e32 v0, 0x7fff, v0
+ ; GFX11CHECK-NEXT:    v_cmp_lt_i16_e32 vcc_lo, 0x7f80, v0
+ ; GFX11CHECK-NEXT:    v_cndmask_b32_e64 v0, 0, 1, vcc_lo
+ ; GFX11CHECK-NEXT:    s_setpc_b64 s[30:31]
+   %1 = call i1 @llvm.is.fpclass.bf16(bfloat %x, i32 3) strictfp ; nan
+   ret i1 %1
+ }
 
 define i1 @isinf_bf16(bfloat %x) nounwind {
 ; GFX7CHECK-LABEL: isinf_bf16:
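
For readers following the diff, the "ext + shift" expansion mentioned in the BF16_TO_FP comment and the mask-and-compare NaN test visible in the CHECK lines can be sketched in standalone C++. This is only an illustration, not code from the patch: the helper names `bf16BitsToFloat` and `isNaNBF16Bits` are hypothetical, and the sketch assumes `float` is a 32-bit IEEE-754 type.

```cpp
#include <cstdint>
#include <cstring>

static_assert(sizeof(float) == sizeof(std::uint32_t),
              "sketch assumes a 32-bit float");

// Hypothetical helper: widen a bf16 bit pattern to f32 by zero-extending it
// to 32 bits and shifting it into the upper half. Every bf16 value is exactly
// representable in f32, so no rounding is involved.
inline float bf16BitsToFloat(std::uint16_t Bits) {
  std::uint32_t Wide = static_cast<std::uint32_t>(Bits) << 16;
  float Result;
  std::memcpy(&Result, &Wide, sizeof(Result));
  return Result;
}

// Hypothetical helper: NaN test on the raw bits, mirroring the and/compare
// pair in the CHECK lines. Clear the sign bit, then compare against the
// +infinity pattern 0x7f80; anything larger is a NaN.
inline bool isNaNBF16Bits(std::uint16_t Bits) {
  return (Bits & 0x7fff) > 0x7f80;
}
```

Note that the NaN test never performs a floating-point operation, which is consistent with the generated code for the strictfp test containing only integer instructions.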

@shiltian shiltian requested a review from Pierre-vh January 30, 2024 20:36
github-actions bot commented Jan 30, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff ccc48d45b832def14c8bc1849cf64c805892368d 7eab42feee2c64574b25e9e8f0c658346dec3bd2 -- compiler-rt/lib/builtins/extendbfsf2.c compiler-rt/lib/builtins/fp_extend.h llvm/include/llvm/CodeGen/ISDOpcodes.h llvm/include/llvm/CodeGen/SelectionDAGNodes.h llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp llvm/lib/CodeGen/TargetLoweringBase.cpp llvm/lib/Target/X86/X86ISelLowering.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 18ca17e53d..ec6f6dbcf0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,9 +380,11 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::FP_TO_FP16:                 return "fp_to_fp16";
   case ISD::STRICT_FP_TO_FP16:          return "strict_fp_to_fp16";
   case ISD::BF16_TO_FP:                 return "bf16_to_fp";
-  case ISD::STRICT_BF16_TO_FP:          return "strict_bf16_to_fp";
+  case ISD::STRICT_BF16_TO_FP:
+    return "strict_bf16_to_fp";
   case ISD::FP_TO_BF16:                 return "fp_to_bf16";
-  case ISD::STRICT_FP_TO_BF16:          return "strict_fp_to_bf16";
+  case ISD::STRICT_FP_TO_BF16:
+    return "strict_fp_to_bf16";
   case ISD::LROUND:                     return "lround";
   case ISD::STRICT_LROUND:              return "strict_lround";
   case ISD::LLROUND:                    return "llround";

@RKSimon RKSimon requested a review from phoebewang January 31, 2024 11:27
@shiltian shiltian force-pushed the PR78540 branch 2 times, most recently from 5162fb2 to 3c93e21 Compare January 31, 2024 19:59
@arsenm arsenm left a comment

To make progress, can you split the non-lib call expansion into a separate change? It's easier to start with the boilerplate and the case that doesn't require writing new infrastructure.

@shiltian shiltian force-pushed the PR78540 branch 2 times, most recently from e71b92d to b07da86 Compare February 10, 2024 00:35
shiltian commented Feb 10, 2024

To make progress, can you split the non-lib call expansion to a separate change? It's easier to start with the boilerplate and case that doesn't require writing new infrastructure

Removed the non-lib call expansion for now. However, I don't know how to test it now because there is no legal lowering that I can test.

shiltian commented Feb 28, 2024

@phoebewang Does X86 have lib calls for STRICT_BF16_TO_FP and STRICT_FP_TO_BF16? I'm trying to find a target to test the lowering via the lib call path.

@phoebewang
Contributor

@phoebewang Does X86 have lib calls for STRICT_BF16_TO_FP and STRICT_FP_TO_BF16? I'm trying to find a target to test the lowering via the lib call path.

Do you mean this? https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/RuntimeLibcalls.def#L322

I see you use the same libcall as the non-strict one. But we don't provide a libcall for the extend.

@shiltian
Contributor Author

I'm not sure how to design a valid lowering that can test the lib call path. I designed an X86 test case similar to the one in llvm/test/CodeGen/X86/half-constrained.ll, as shown below:

define void @float_to_bfloat(float %0) strictfp {
  %2 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float %0, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
  store bfloat %2, ptr @a, align 2
  ret void
}

If the type is half, it will go through the lib call path. However, if it is bfloat, it will not (I believe it goes through soft promotion).

arsenm commented Feb 29, 2024

define void @float_to_bfloat(float %0) strictfp {
  %2 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float %0, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
  store bfloat %2, ptr @a, align 2
  ret void
}

For this example, I see an assert with -mtriple=x86_64--. "Assertion failed: (RVT == MVT::f16), function SoftPromoteHalfRes_FP_ROUND". For aarch64, I see a libcall to __truncsfbf2
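
As background on what a truncation libcall such as `__truncsfbf2` has to compute, here is a minimal standalone sketch of an f32 to bf16 conversion with round-to-nearest-even, matching the "round.tonearest" metadata in the example. It is not compiler-rt's implementation; the function name `floatToBF16BitsRNE` is hypothetical, and the sketch assumes a 32-bit IEEE-754 `float`.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: convert an f32 value to a bf16 bit pattern using
// round-to-nearest, ties-to-even.
inline std::uint16_t floatToBF16BitsRNE(float Value) {
  std::uint32_t Bits;
  std::memcpy(&Bits, &Value, sizeof(Bits));

  // NaN input: keep the sign and the top payload bits, and force the quiet
  // bit so the result cannot collapse into an infinity.
  if ((Bits & 0x7fffffffu) > 0x7f800000u)
    return static_cast<std::uint16_t>((Bits >> 16) | 0x0040u);

  // Add 0x7fff plus the lowest bit that will be kept. A value exactly halfway
  // between two bf16 neighbours then rounds to the one with an even mantissa.
  std::uint32_t Lsb = (Bits >> 16) & 1u;
  Bits += 0x7fffu + Lsb;
  return static_cast<std::uint16_t>(Bits >> 16);
}
```

For example, 1.00390625f (exactly halfway between two bf16 values) rounds down to the even pattern 0x3f80, while 1.01171875f rounds up to 0x3f82. The sketch uses integer operations only, so on its own it reports no floating-point exceptions.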

@phoebewang
Contributor

I'm not sure how to design a valid lowering that can test the lib call path. Designed a similar X86 test case as the one from llvm/test/CodeGen/X86/half-constrained.ll, as shown below:

define void @float_to_bfloat(float %0) strictfp {
  %2 = tail call bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float %0, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
  store bfloat %2, ptr @a, align 2
  ret void
}

If the type is half, it will go through the lib call path. However, if it is bfloat, it will not (I believe it goes through soft promotion).

The example can pass if you add expand to X86ISelLowering.cpp

setOperationAction(ISD::STRICT_FP_TO_BF16, MVT::f32, Expand);

; declare float @llvm.experimental.constrained.fadd.f32(float, float, metadata, metadata)
declare bfloat @llvm.experimental.constrained.fptrunc.bfloat.f32(float, metadata, metadata)
declare bfloat @llvm.experimental.constrained.fptrunc.bfloat.f64(double, metadata, metadata)

Contributor

Test vectors too?

Contributor

Test STRICT_FP16_TO_FP? I think the key is that we have to lower it to a new libcall instead of expanding it through integer operations.
I'm also interested in a test with -mattr=+avx512bf16,+avx512vl, because VCVTNEPS2BF16 doesn't raise exceptions. We should lower it to a libcall too.
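
One way to see the concern about integer expansion is the standalone sketch below. It assumes the point is that a conversion built purely from integer operations never touches the floating-point status flags, so an inexact result goes unreported. The helper names are hypothetical and nothing here is taken from the patch; rounding is omitted to keep the sketch short.

```cpp
#include <cfenv>
#include <cstdint>
#include <cstring>

// Hypothetical helper: truncate f32 to bf16 using integer operations only
// (the low 16 bits are simply dropped; no rounding is performed).
inline std::uint16_t truncateViaIntegerOps(float Value) {
  std::uint32_t Bits;
  std::memcpy(&Bits, &Value, sizeof(Bits));
  return static_cast<std::uint16_t>(Bits >> 16);
}

// Hypothetical helper: report whether the integer-only conversion of Value
// set the FE_INEXACT status flag.
inline bool integerPathRaisesInexact(float Value) {
  std::feclearexcept(FE_ALL_EXCEPT);
  volatile std::uint16_t Sink = truncateViaIntegerOps(Value);
  (void)Sink;
  return std::fetestexcept(FE_INEXACT) != 0;
}
```

A value such as 1.1f is not representable in bf16, so a conversion that honors strict exception semantics would be expected to signal inexact for it, yet `integerPathRaisesInexact(1.1f)` stays false.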

Contributor Author

STRICT_FP16_TO_FP would be out of the scope of this patch (as well as the file) since it already exists.

@phoebewang phoebewang Mar 1, 2024

Sorry, I meant STRICT_BF16_TO_FP.

Contributor Author

Are there any existing GNU library functions for that? From https://gcc.gnu.org/pipermail/libstdc++-cvs/2023q1/039390.html it looks like there might be.

Contributor

Yes, we may need to add it to compiler-rt too.

Contributor Author

I'm also interested in test with -mattr=+avx512bf16,+avx512vl, because the VCVTNEPS2BF16 doesn't raise exception. We should lower it to libcall too.

I'm not familiar with the X86 backend. That can be done in a follow-up patch by someone with that expertise.

@@ -393,7 +393,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
   }
 
   for (auto Op : {ISD::FP16_TO_FP, ISD::STRICT_FP16_TO_FP, ISD::FP_TO_FP16,
-                  ISD::STRICT_FP_TO_FP16}) {
+                  ISD::STRICT_FP_TO_FP16, ISD::STRICT_FP_TO_BF16}) {
Contributor Author

@phoebewang Do you mean here, we need to drop Custom for Subtarget.hasF16C() if Op is ISD::STRICT_FP_TO_BF16? If yes, I will take it out of the loop.

Contributor

Yes.

Contributor Author

I don't know how to implement the Custom path since I'm not familiar with which instructions need to be used. I left a FIXME there.

@@ -81,6 +81,13 @@ static inline int src_rep_t_clz_impl(src_rep_t a) {

#define src_rep_t_clz src_rep_t_clz_impl

#elif defined SRC_BFLOAT
Contributor Author

The changes in compiler-rt were from https://reviews.llvm.org/D151436.

Contributor

I'm not sure whether the implementation handles exceptions, but I don't have any suggestions for it.

; X32-NEXT: movzwl a, %eax
; X32-NEXT: movl %eax, (%esp)
; X32-NEXT: calll __extendbfsf2
; X32-NEXT: addl $12, %esp
Contributor

This seems to be lacking a float-to-double conversion.

Contributor Author

I checked the DAG. The extend node was indeed inserted, but I'm not sure why nothing is emitted for it. Looking at llvm/test/CodeGen/X86/half-constrained.ll, it seems that the output just looks like that on the X86 target for some reason.

Contributor

Oh, I see. It is using x87 registers. IIRC, the data is kept at high precision in the register, so we save an extra conversion.

@shiltian shiltian force-pushed the PR78540 branch 2 times, most recently from e2f46be to bd5196e Compare March 3, 2024 15:43
@phoebewang phoebewang left a comment

LGTM with one nit, thanks!

This patch adds the support for `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16`.

Fix llvm#78540.
@shiltian shiltian merged commit b0c158b into llvm:main Mar 4, 2024
@shiltian shiltian deleted the PR78540 branch March 4, 2024 05:01
shiltian added a commit that referenced this pull request Mar 4, 2024
#80056)"

This reverts commit b0c158b.

The changes in `compiler-rt` broke tests.
shiltian added a commit that referenced this pull request Mar 4, 2024
This patch adds the support for `STRICT_BF16_TO_FP` and
`STRICT_FP_TO_BF16`.