Skip to content

Conversation

arsenm
Copy link
Contributor

@arsenm arsenm commented Dec 4, 2023

Needs an AMDGPU hack to get the selection to work. The ordinary
variant is custom lowered through an almost equivalent target node
that would need a strict variant for additional known bits optimizations.

Test is a placeholder, will be merged into the existing
test after additional bug fixes for illegal f16 targets are fixed.
@llvmbot llvmbot added backend:AMDGPU llvm:SelectionDAG SelectionDAGISel as well labels Dec 4, 2023
@arsenm arsenm added floating-point Floating-point math and removed backend:AMDGPU llvm:SelectionDAG SelectionDAGISel as well labels Dec 4, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes

Needs an AMDGPU hack to get the selection to work. The ordinary
variant is custom lowered through an almost equivalent target node
that would need a strict variant for additional known bits optimizations.


Full diff: https://github.com/llvm/llvm-project/pull/74332.diff

6 Files Affected:

  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+63)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+14-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+8-1)
  • (added) llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll (+110)
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 798e6a1d9525e..de7bf26868b34 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -614,6 +614,12 @@ def strict_sint_to_fp : SDNode<"ISD::STRICT_SINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
 def strict_uint_to_fp : SDNode<"ISD::STRICT_UINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
+
+def strict_f16_to_fp  : SDNode<"ISD::STRICT_FP16_TO_FP",
+                               SDTIntToFPOp, [SDNPHasChain]>;
+def strict_fp_to_f16  : SDNode<"ISD::STRICT_FP_TO_FP16",
+                               SDTFPToIntOp, [SDNPHasChain]>;
+
 def strict_fsetcc  : SDNode<"ISD::STRICT_FSETCC",  SDTSetCC, [SDNPHasChain]>;
 def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
 
@@ -1558,6 +1564,13 @@ def any_fsetccs : PatFrags<(ops node:$lhs, node:$rhs, node:$pred),
                           [(strict_fsetccs node:$lhs, node:$rhs, node:$pred),
                            (setcc node:$lhs, node:$rhs, node:$pred)]>;
 
+def any_f16_to_fp : PatFrags<(ops node:$src),
+                              [(f16_to_fp node:$src),
+                               (strict_f16_to_fp node:$src)]>;
+def any_fp_to_f16 : PatFrags<(ops node:$src),
+                              [(fp_to_f16 node:$src),
+                               (strict_fp_to_f16 node:$src)]>;
+
 multiclass binary_atomic_op_ord {
   def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
       (!cast<SDPatternOperator>(NAME) node:$ptr, node:$val)> {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 630aa4a07d7b9..d7a688511b726 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2181,6 +2181,20 @@ static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
 
+static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
+  if (OpVT == MVT::f16) {
+    return ISD::STRICT_FP16_TO_FP;
+  } else if (RetVT == MVT::f16) {
+    return ISD::STRICT_FP_TO_FP16;
+  } else if (OpVT == MVT::bf16) {
+    // return ISD::STRICT_BF16_TO_FP;
+  } else if (RetVT == MVT::bf16) {
+    // return ISD::STRICT_FP_TO_BF16;
+  }
+
+  report_fatal_error("Attempt at an invalid promotion-related conversion");
+}
+
 bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
   LLVM_DEBUG(dbgs() << "Promote float operand " << OpNo << ": "; N->dump(&DAG));
   SDValue R = SDValue();
@@ -2214,6 +2228,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::FP_TO_UINT_SAT:
                           R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
     case ISD::FP_EXTEND:  R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
+    case ISD::STRICT_FP_EXTEND:
+      R = PromoteFloatOp_STRICT_FP_EXTEND(N, OpNo);
+      break;
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
     case ISD::STORE:      R = PromoteFloatOp_STORE(N, OpNo); break;
@@ -2276,6 +2293,26 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo) {
   return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, Op);
 }
 
+SDValue DAGTypeLegalizer::PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N,
+                                                          unsigned OpNo) {
+  assert(OpNo == 1);
+
+  SDValue Op = GetPromotedFloat(N->getOperand(1));
+  EVT VT = N->getValueType(0);
+
+  // Desired VT is same as promoted type.  Use promoted float directly.
+  if (VT == Op->getValueType(0)) {
+    ReplaceValueWith(SDValue(N, 1), N->getOperand(0));
+    return Op;
+  }
+
+  // Else, extend the promoted float value to the desired VT.
+  SDValue Res = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N), N->getVTList(),
+                            N->getOperand(0), Op);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 // Promote the float operands used for comparison.  The true- and false-
 // operands have the same type as the result and are promoted, if needed, by
 // PromoteFloatRes_SELECT_CC
@@ -2393,6 +2430,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::FFREXP:     R = PromoteFloatRes_FFREXP(N); break;
 
     case ISD::FP_ROUND:   R = PromoteFloatRes_FP_ROUND(N); break;
+    case ISD::STRICT_FP_ROUND:
+      R = PromoteFloatRes_STRICT_FP_ROUND(N);
+      break;
     case ISD::LOAD:       R = PromoteFloatRes_LOAD(N); break;
     case ISD::SELECT:     R = PromoteFloatRes_SELECT(N); break;
     case ISD::SELECT_CC:  R = PromoteFloatRes_SELECT_CC(N); break;
@@ -2598,6 +2638,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
   return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
 }
 
+// Explicit operation to reduce precision.  Reduce the value to half precision
+// and promote it back to the legal type.
+SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
+  SDLoc DL(N);
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Op = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT OpVT = Op->getValueType(0);
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+
+  // Round promoted float to desired precision
+  SDValue Round = DAG.getNode(GetPromotionOpcodeStrict(OpVT, VT), DL,
+                              DAG.getVTList(IVT, MVT::Other), Chain, Op);
+  // Promote it back to the legal output type
+  SDValue Res =
+      DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
+                  DAG.getVTList(NVT, MVT::Other), Round.getValue(1), Round);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteFloatRes_LOAD(SDNode *N) {
   LoadSDNode *L = cast<LoadSDNode>(N);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 54698edce7d6f..887670fb6baff 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,7 +165,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
-
+  case ISD::STRICT_FP_TO_FP16:
+    Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
+    break;
   case ISD::GET_ROUNDING: Res = PromoteIntRes_GET_ROUNDING(N); break;
 
   case ISD::AND:
@@ -787,6 +789,16 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
   return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N) {
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDLoc dl(N);
+
+  SDValue Res = DAG.getNode(N->getOpcode(), dl, DAG.getVTList(NVT, MVT::Other),
+                            N->getOperand(0), N->getOperand(1));
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDLoc dl(N);
@@ -1804,6 +1816,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::FP16_TO_FP:
   case ISD::VP_UINT_TO_FP:
   case ISD::UINT_TO_FP:   Res = PromoteIntOp_UINT_TO_FP(N); break;
+  case ISD::STRICT_FP16_TO_FP:
   case ISD::STRICT_UINT_TO_FP:  Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
   case ISD::ZERO_EXTEND:  Res = PromoteIntOp_ZERO_EXTEND(N); break;
   case ISD::VP_ZERO_EXTEND: Res = PromoteIntOp_VP_ZERO_EXTEND(N); break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index e9bd54089d062..9361e7fff2190 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
+  SDValue PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N);
   SDValue PromoteIntRes_XRINT(SDNode *N);
   SDValue PromoteIntRes_FREEZE(SDNode *N);
   SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
@@ -698,6 +699,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatRes_ExpOp(SDNode *N);
   SDValue PromoteFloatRes_FFREXP(SDNode *N);
   SDValue PromoteFloatRes_FP_ROUND(SDNode *N);
+  SDValue PromoteFloatRes_STRICT_FP_ROUND(SDNode *N);
   SDValue PromoteFloatRes_LOAD(SDNode *N);
   SDValue PromoteFloatRes_SELECT(SDNode *N);
   SDValue PromoteFloatRes_SELECT_CC(SDNode *N);
@@ -712,6 +714,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index 9362fe5d9678b..d23d6b4f93da8 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -1097,7 +1097,7 @@ def : Pat <
 multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16_inst_e64> {
   // f16_to_fp patterns
   def : GCNPat <
-    (f32 (f16_to_fp i32:$src0)),
+    (f32 (any_f16_to_fp i32:$src0)),
     (cvt_f32_f16_inst_e64 SRCMODS.NONE, $src0)
   >;
 
@@ -1151,6 +1151,13 @@ multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16
     (f16 (uint_to_fp i32:$src)),
     (cvt_f16_f32_inst_e64 SRCMODS.NONE, (V_CVT_F32_U32_e32 VSrc_b32:$src))
   >;
+
+  // This is only used on targets without half support
+  // TODO: Introduce strict variant of AMDGPUfp_to_f16 and share custom lowering
+  def : GCNPat <
+    (i32 (strict_fp_to_f16 (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
+    (cvt_f16_f32_inst_e64 $src0_modifiers, f32:$src0)
+  >;
 }
 
 let SubtargetPredicate = NotHasTrue16BitInsts in
diff --git a/llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll b/llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll
new file mode 100644
index 0000000000000..8a3647a9b6e93
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll
@@ -0,0 +1,110 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=amdgcn-mesa-mesa3d -mcpu=hawaii < %s | FileCheck -check-prefixes=GFX7 %s
+
+declare float @llvm.experimental.constrained.fpext.f32.f16(half, metadata) #0
+declare <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half>, metadata) #0
+declare half @llvm.experimental.constrained.fptrunc.f16.f32(float, metadata, metadata) #0
+declare <2 x half> @llvm.experimental.constrained.fptrunc.v2f16.v2f32(<2 x float>, metadata, metadata) #0
+declare float @llvm.fabs.f32(float)
+
+define float @v_constrained_fpext_f16_to_f32(ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fpext_f16_to_f32:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    s_mov_b32 s6, 0
+; GFX7-NEXT:    s_mov_b32 s7, 0xf000
+; GFX7-NEXT:    s_mov_b32 s4, s6
+; GFX7-NEXT:    s_mov_b32 s5, s6
+; GFX7-NEXT:    buffer_load_ushort v0, v[0:1], s[4:7], 0 addr64
+; GFX7-NEXT:    s_waitcnt vmcnt(0)
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %val = load half, ptr addrspace(1) %ptr
+  %result = call float @llvm.experimental.constrained.fpext.f32.f16(half %val, metadata !"fpexcept.strict")
+  ret float %result
+}
+
+define <2 x float> @v_constrained_fpext_v2f16_to_v2f32(ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fpext_v2f16_to_v2f32:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    s_mov_b32 s6, 0
+; GFX7-NEXT:    s_mov_b32 s7, 0xf000
+; GFX7-NEXT:    s_mov_b32 s4, s6
+; GFX7-NEXT:    s_mov_b32 s5, s6
+; GFX7-NEXT:    buffer_load_dword v1, v[0:1], s[4:7], 0 addr64
+; GFX7-NEXT:    s_waitcnt vmcnt(0)
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v1
+; GFX7-NEXT:    v_lshrrev_b32_e32 v1, 16, v1
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v1, v1
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %val = load <2 x half>, ptr addrspace(1) %ptr
+  %result = call <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half> %val, metadata !"fpexcept.strict")
+  ret <2 x float> %result
+}
+
+define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict(float %arg, ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v0, v0
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
+  ret void
+}
+
+define void @v_constrained_fptrunc_v2f32_to_v2f16_fpexcept_strict(<2 x float> %arg, ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fptrunc_v2f32_to_v2f16_fpexcept_strict:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v1, v1
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v0, v0
+; GFX7-NEXT:    s_mov_b32 s6, 0
+; GFX7-NEXT:    s_mov_b32 s7, 0xf000
+; GFX7-NEXT:    v_and_b32_e32 v1, 0xffff, v1
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v1, v1
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_mov_b32 s4, s6
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v1, v1
+; GFX7-NEXT:    s_mov_b32 s5, s6
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v0, v0
+; GFX7-NEXT:    v_lshlrev_b32_e32 v1, 16, v1
+; GFX7-NEXT:    v_or_b32_e32 v0, v0, v1
+; GFX7-NEXT:    buffer_store_dword v0, v[2:3], s[4:7], 0 addr64
+; GFX7-NEXT:    s_waitcnt vmcnt(0)
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %result = call <2 x half> @llvm.experimental.constrained.fptrunc.v2f16.v2f32(<2 x float> %arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
+  store <2 x half> %result, ptr addrspace(1) %ptr
+  ret void
+}
+
+define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fneg(float %arg, ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fneg:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    v_cvt_f16_f32_e64 v0, -v0
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %neg.arg = fneg float %arg
+  %result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %neg.arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
+  ret void
+}
+
+define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fabs(float %arg, ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fabs:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    v_cvt_f16_f32_e64 v0, |v0|
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %abs.arg = call float @llvm.fabs.f32(float %arg)
+  %result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %abs.arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
+  ret void
+}
+
+attributes #0 = { strictfp }

@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-llvm-selectiondag

Author: Matt Arsenault (arsenm)

Changes

Needs an AMDGPU hack to get the selection to work. The ordinary
variant is custom lowered through an almost equivalent target node
that would need a strict variant for additional known bits optimizations.


Full diff: https://github.com/llvm/llvm-project/pull/74332.diff

6 Files Affected:

  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+63)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+14-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+8-1)
  • (added) llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll (+110)
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 798e6a1d9525e..de7bf26868b34 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -614,6 +614,12 @@ def strict_sint_to_fp : SDNode<"ISD::STRICT_SINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
 def strict_uint_to_fp : SDNode<"ISD::STRICT_UINT_TO_FP",
                                SDTIntToFPOp, [SDNPHasChain]>;
+
+def strict_f16_to_fp  : SDNode<"ISD::STRICT_FP16_TO_FP",
+                               SDTIntToFPOp, [SDNPHasChain]>;
+def strict_fp_to_f16  : SDNode<"ISD::STRICT_FP_TO_FP16",
+                               SDTFPToIntOp, [SDNPHasChain]>;
+
 def strict_fsetcc  : SDNode<"ISD::STRICT_FSETCC",  SDTSetCC, [SDNPHasChain]>;
 def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
 
@@ -1558,6 +1564,13 @@ def any_fsetccs : PatFrags<(ops node:$lhs, node:$rhs, node:$pred),
                           [(strict_fsetccs node:$lhs, node:$rhs, node:$pred),
                            (setcc node:$lhs, node:$rhs, node:$pred)]>;
 
+def any_f16_to_fp : PatFrags<(ops node:$src),
+                              [(f16_to_fp node:$src),
+                               (strict_f16_to_fp node:$src)]>;
+def any_fp_to_f16 : PatFrags<(ops node:$src),
+                              [(fp_to_f16 node:$src),
+                               (strict_fp_to_f16 node:$src)]>;
+
 multiclass binary_atomic_op_ord {
   def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
       (!cast<SDPatternOperator>(NAME) node:$ptr, node:$val)> {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 630aa4a07d7b9..d7a688511b726 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2181,6 +2181,20 @@ static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
   report_fatal_error("Attempt at an invalid promotion-related conversion");
 }
 
+static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
+  if (OpVT == MVT::f16) {
+    return ISD::STRICT_FP16_TO_FP;
+  } else if (RetVT == MVT::f16) {
+    return ISD::STRICT_FP_TO_FP16;
+  } else if (OpVT == MVT::bf16) {
+    // return ISD::STRICT_BF16_TO_FP;
+  } else if (RetVT == MVT::bf16) {
+    // return ISD::STRICT_FP_TO_BF16;
+  }
+
+  report_fatal_error("Attempt at an invalid promotion-related conversion");
+}
+
 bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
   LLVM_DEBUG(dbgs() << "Promote float operand " << OpNo << ": "; N->dump(&DAG));
   SDValue R = SDValue();
@@ -2214,6 +2228,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::FP_TO_UINT_SAT:
                           R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
     case ISD::FP_EXTEND:  R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
+    case ISD::STRICT_FP_EXTEND:
+      R = PromoteFloatOp_STRICT_FP_EXTEND(N, OpNo);
+      break;
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
     case ISD::STORE:      R = PromoteFloatOp_STORE(N, OpNo); break;
@@ -2276,6 +2293,26 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo) {
   return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, Op);
 }
 
+SDValue DAGTypeLegalizer::PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N,
+                                                          unsigned OpNo) {
+  assert(OpNo == 1);
+
+  SDValue Op = GetPromotedFloat(N->getOperand(1));
+  EVT VT = N->getValueType(0);
+
+  // Desired VT is same as promoted type.  Use promoted float directly.
+  if (VT == Op->getValueType(0)) {
+    ReplaceValueWith(SDValue(N, 1), N->getOperand(0));
+    return Op;
+  }
+
+  // Else, extend the promoted float value to the desired VT.
+  SDValue Res = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N), N->getVTList(),
+                            N->getOperand(0), Op);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 // Promote the float operands used for comparison.  The true- and false-
 // operands have the same type as the result and are promoted, if needed, by
 // PromoteFloatRes_SELECT_CC
@@ -2393,6 +2430,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::FFREXP:     R = PromoteFloatRes_FFREXP(N); break;
 
     case ISD::FP_ROUND:   R = PromoteFloatRes_FP_ROUND(N); break;
+    case ISD::STRICT_FP_ROUND:
+      R = PromoteFloatRes_STRICT_FP_ROUND(N);
+      break;
     case ISD::LOAD:       R = PromoteFloatRes_LOAD(N); break;
     case ISD::SELECT:     R = PromoteFloatRes_SELECT(N); break;
     case ISD::SELECT_CC:  R = PromoteFloatRes_SELECT_CC(N); break;
@@ -2598,6 +2638,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
   return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
 }
 
+// Explicit operation to reduce precision.  Reduce the value to half precision
+// and promote it back to the legal type.
+SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
+  SDLoc DL(N);
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Op = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT OpVT = Op->getValueType(0);
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+
+  // Round promoted float to desired precision
+  SDValue Round = DAG.getNode(GetPromotionOpcodeStrict(OpVT, VT), DL,
+                              DAG.getVTList(IVT, MVT::Other), Chain, Op);
+  // Promote it back to the legal output type
+  SDValue Res =
+      DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
+                  DAG.getVTList(NVT, MVT::Other), Round.getValue(1), Round);
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteFloatRes_LOAD(SDNode *N) {
   LoadSDNode *L = cast<LoadSDNode>(N);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 54698edce7d6f..887670fb6baff 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -165,7 +165,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
     break;
-
+  case ISD::STRICT_FP_TO_FP16:
+    Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
+    break;
   case ISD::GET_ROUNDING: Res = PromoteIntRes_GET_ROUNDING(N); break;
 
   case ISD::AND:
@@ -787,6 +789,16 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
   return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N) {
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDLoc dl(N);
+
+  SDValue Res = DAG.getNode(N->getOpcode(), dl, DAG.getVTList(NVT, MVT::Other),
+                            N->getOperand(0), N->getOperand(1));
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDLoc dl(N);
@@ -1804,6 +1816,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::FP16_TO_FP:
   case ISD::VP_UINT_TO_FP:
   case ISD::UINT_TO_FP:   Res = PromoteIntOp_UINT_TO_FP(N); break;
+  case ISD::STRICT_FP16_TO_FP:
   case ISD::STRICT_UINT_TO_FP:  Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
   case ISD::ZERO_EXTEND:  Res = PromoteIntOp_ZERO_EXTEND(N); break;
   case ISD::VP_ZERO_EXTEND: Res = PromoteIntOp_VP_ZERO_EXTEND(N); break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index e9bd54089d062..9361e7fff2190 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
+  SDValue PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N);
   SDValue PromoteIntRes_XRINT(SDNode *N);
   SDValue PromoteIntRes_FREEZE(SDNode *N);
   SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
@@ -698,6 +699,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatRes_ExpOp(SDNode *N);
   SDValue PromoteFloatRes_FFREXP(SDNode *N);
   SDValue PromoteFloatRes_FP_ROUND(SDNode *N);
+  SDValue PromoteFloatRes_STRICT_FP_ROUND(SDNode *N);
   SDValue PromoteFloatRes_LOAD(SDNode *N);
   SDValue PromoteFloatRes_SELECT(SDNode *N);
   SDValue PromoteFloatRes_SELECT_CC(SDNode *N);
@@ -712,6 +714,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_STRICT_FP_EXTEND(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index 9362fe5d9678b..d23d6b4f93da8 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -1097,7 +1097,7 @@ def : Pat <
 multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16_inst_e64> {
   // f16_to_fp patterns
   def : GCNPat <
-    (f32 (f16_to_fp i32:$src0)),
+    (f32 (any_f16_to_fp i32:$src0)),
     (cvt_f32_f16_inst_e64 SRCMODS.NONE, $src0)
   >;
 
@@ -1151,6 +1151,13 @@ multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16
     (f16 (uint_to_fp i32:$src)),
     (cvt_f16_f32_inst_e64 SRCMODS.NONE, (V_CVT_F32_U32_e32 VSrc_b32:$src))
   >;
+
+  // This is only used on targets without half support
+  // TODO: Introduce strict variant of AMDGPUfp_to_f16 and share custom lowering
+  def : GCNPat <
+    (i32 (strict_fp_to_f16 (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
+    (cvt_f16_f32_inst_e64 $src0_modifiers, f32:$src0)
+  >;
 }
 
 let SubtargetPredicate = NotHasTrue16BitInsts in
diff --git a/llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll b/llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll
new file mode 100644
index 0000000000000..8a3647a9b6e93
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll
@@ -0,0 +1,110 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=amdgcn-mesa-mesa3d -mcpu=hawaii < %s | FileCheck -check-prefixes=GFX7 %s
+
+declare float @llvm.experimental.constrained.fpext.f32.f16(half, metadata) #0
+declare <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half>, metadata) #0
+declare half @llvm.experimental.constrained.fptrunc.f16.f32(float, metadata, metadata) #0
+declare <2 x half> @llvm.experimental.constrained.fptrunc.v2f16.v2f32(<2 x float>, metadata, metadata) #0
+declare float @llvm.fabs.f32(float)
+
+define float @v_constrained_fpext_f16_to_f32(ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fpext_f16_to_f32:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    s_mov_b32 s6, 0
+; GFX7-NEXT:    s_mov_b32 s7, 0xf000
+; GFX7-NEXT:    s_mov_b32 s4, s6
+; GFX7-NEXT:    s_mov_b32 s5, s6
+; GFX7-NEXT:    buffer_load_ushort v0, v[0:1], s[4:7], 0 addr64
+; GFX7-NEXT:    s_waitcnt vmcnt(0)
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %val = load half, ptr addrspace(1) %ptr
+  %result = call float @llvm.experimental.constrained.fpext.f32.f16(half %val, metadata !"fpexcept.strict")
+  ret float %result
+}
+
+define <2 x float> @v_constrained_fpext_v2f16_to_v2f32(ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fpext_v2f16_to_v2f32:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    s_mov_b32 s6, 0
+; GFX7-NEXT:    s_mov_b32 s7, 0xf000
+; GFX7-NEXT:    s_mov_b32 s4, s6
+; GFX7-NEXT:    s_mov_b32 s5, s6
+; GFX7-NEXT:    buffer_load_dword v1, v[0:1], s[4:7], 0 addr64
+; GFX7-NEXT:    s_waitcnt vmcnt(0)
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v1
+; GFX7-NEXT:    v_lshrrev_b32_e32 v1, 16, v1
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v1, v1
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %val = load <2 x half>, ptr addrspace(1) %ptr
+  %result = call <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half> %val, metadata !"fpexcept.strict")
+  ret <2 x float> %result
+}
+
+define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict(float %arg, ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v0, v0
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
+  ret void
+}
+
+define void @v_constrained_fptrunc_v2f32_to_v2f16_fpexcept_strict(<2 x float> %arg, ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fptrunc_v2f32_to_v2f16_fpexcept_strict:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v1, v1
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v0, v0
+; GFX7-NEXT:    s_mov_b32 s6, 0
+; GFX7-NEXT:    s_mov_b32 s7, 0xf000
+; GFX7-NEXT:    v_and_b32_e32 v1, 0xffff, v1
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v1, v1
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_mov_b32 s4, s6
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v1, v1
+; GFX7-NEXT:    s_mov_b32 s5, s6
+; GFX7-NEXT:    v_cvt_f16_f32_e32 v0, v0
+; GFX7-NEXT:    v_lshlrev_b32_e32 v1, 16, v1
+; GFX7-NEXT:    v_or_b32_e32 v0, v0, v1
+; GFX7-NEXT:    buffer_store_dword v0, v[2:3], s[4:7], 0 addr64
+; GFX7-NEXT:    s_waitcnt vmcnt(0)
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %result = call <2 x half> @llvm.experimental.constrained.fptrunc.v2f16.v2f32(<2 x float> %arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
+  store <2 x half> %result, ptr addrspace(1) %ptr
+  ret void
+}
+
+define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fneg(float %arg, ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fneg:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    v_cvt_f16_f32_e64 v0, -v0
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %neg.arg = fneg float %arg
+  %result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %neg.arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
+  ret void
+}
+
+define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fabs(float %arg, ptr addrspace(1) %ptr) #0 {
+; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fabs:
+; GFX7:       ; %bb.0:
+; GFX7-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX7-NEXT:    v_cvt_f16_f32_e64 v0, |v0|
+; GFX7-NEXT:    v_and_b32_e32 v0, 0xffff, v0
+; GFX7-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX7-NEXT:    s_setpc_b64 s[30:31]
+  %abs.arg = call float @llvm.fabs.f32(float %arg)
+  %result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %abs.arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
+  ret void
+}
+
+attributes #0 = { strictfp }

Needs an AMDGPU hack to get the selection to work. The ordinary
variant is custom lowered through an almost equivalent target node
that would need a strict variant for additional known bits optimizations.
Copy link

github-actions bot commented Dec 6, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@arsenm arsenm force-pushed the dag-promote-strict-fptrunc branch from 66316f8 to ffb23d6 Compare December 6, 2023 04:40
@llvmbot llvmbot added backend:AMDGPU llvm:SelectionDAG SelectionDAGISel as well labels Dec 6, 2023
@arsenm
Copy link
Contributor Author

arsenm commented Jan 4, 2024

ping

@andykaylor andykaylor requested a review from phoebewang January 4, 2024 18:56
Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me. I've added Phoebe Wang as a reviewer because she knows the SelectionDAG code better than I do and might want to take a look. I'm assuming you have reviewed the AMDGPU-specific parts of this with someone who knows that architecture.

DAG.getVTList(IVT, MVT::Other), Chain, Op);
// Promote it back to the legal output type
SDValue Res =
DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetPromotionOpcode() is used in a lot more places. Are those not needed for strict or are they coming in later patches?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They all need updating at some point

@arsenm arsenm merged commit 597086c into llvm:main Jan 5, 2024
@arsenm arsenm deleted the dag-promote-strict-fptrunc branch January 5, 2024 01:44
@@ -1097,7 +1097,7 @@ def : Pat <
multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16_inst_e64> {
// f16_to_fp patterns
def : GCNPat <
(f32 (f16_to_fp i32:$src0)),
(f32 (any_f16_to_fp i32:$src0)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we replace more f16_to_fp to any_f16_to_fp in this file and other target files?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, with appropriate tests added alongside it

@@ -2621,6 +2642,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
}

// Explicit operation to reduce precision. Reduce the value to half precision
// and promote it back to the legal type.
SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two promotion legalizers, Promote and SoftPromoteHalf. X86 and RISC-V use SoftPromoteHalf. Other targets use Promote. Promote doesn't convert back to fp16 in the right places in my opinion. If I recall correctly Promote doesn't convert to fp16 between two promoted fadds for example so it keeps extra precision.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU floating-point Floating-point math llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants