Skip to content

Commit f381fe1

Browse files
committed
DAG: Implement promotion for strict_fp_round
This needs an AMDGPU hack to get instruction selection to work. The ordinary (non-strict) variant is custom lowered through an almost-equivalent target node, which would need a strict variant of its own for the additional known-bits optimizations.
1 parent 405b870 commit f381fe1

File tree

6 files changed

+144
-2
lines changed

6 files changed

+144
-2
lines changed

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,12 @@ def strict_sint_to_fp : SDNode<"ISD::STRICT_SINT_TO_FP",
614614
SDTIntToFPOp, [SDNPHasChain]>;
615615
def strict_uint_to_fp : SDNode<"ISD::STRICT_UINT_TO_FP",
616616
SDTIntToFPOp, [SDNPHasChain]>;
617+
618+
// Chained (SDNPHasChain) SDNode definitions for ISD::STRICT_FP16_TO_FP and
// ISD::STRICT_FP_TO_FP16. They reuse SDTIntToFPOp / SDTFPToIntOp, the same
// way the neighbouring strict_sint_to_fp / strict_uint_to_fp nodes do.
def strict_f16_to_fp : SDNode<"ISD::STRICT_FP16_TO_FP",
619+
SDTIntToFPOp, [SDNPHasChain]>;
620+
def strict_fp_to_f16 : SDNode<"ISD::STRICT_FP_TO_FP16",
621+
SDTFPToIntOp, [SDNPHasChain]>;
622+
617623
def strict_fsetcc : SDNode<"ISD::STRICT_FSETCC", SDTSetCC, [SDNPHasChain]>;
618624
def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;
619625

@@ -1558,6 +1564,13 @@ def any_fsetccs : PatFrags<(ops node:$lhs, node:$rhs, node:$pred),
15581564
[(strict_fsetccs node:$lhs, node:$rhs, node:$pred),
15591565
(setcc node:$lhs, node:$rhs, node:$pred)]>;
15601566

1567+
// Pattern fragments that match either the ordinary or the strict form of the
// f16 <-> fp conversion nodes, so a pattern written against any_* accepts both.
def any_f16_to_fp : PatFrags<(ops node:$src),
1568+
[(f16_to_fp node:$src),
1569+
(strict_f16_to_fp node:$src)]>;
1570+
def any_fp_to_f16 : PatFrags<(ops node:$src),
1571+
[(fp_to_f16 node:$src),
1572+
(strict_fp_to_f16 node:$src)]>;
1573+
15611574
multiclass binary_atomic_op_ord {
15621575
def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
15631576
(!cast<SDPatternOperator>(NAME) node:$ptr, node:$val)> {

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,6 +2181,20 @@ static ISD::NodeType GetPromotionOpcode(EVT OpVT, EVT RetVT) {
21812181
report_fatal_error("Attempt at an invalid promotion-related conversion");
21822182
}
21832183

2184+
// Pick the strict (chained) conversion opcode used when promoting between
// OpVT and RetVT. An f16 source selects STRICT_FP16_TO_FP and an f16 result
// selects STRICT_FP_TO_FP16. Every other combination, including bf16, ends in
// report_fatal_error.
static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
2185+
if (OpVT == MVT::f16) {
2186+
return ISD::STRICT_FP16_TO_FP;
2187+
} else if (RetVT == MVT::f16) {
2188+
return ISD::STRICT_FP_TO_FP16;
2189+
} else if (OpVT == MVT::bf16) {
2190+
// return ISD::STRICT_BF16_TO_FP; (disabled: this branch is empty, so bf16 sources fall through to the fatal error below)
2191+
} else if (RetVT == MVT::bf16) {
2192+
// return ISD::STRICT_FP_TO_BF16; (disabled: this branch is empty, so bf16 results fall through to the fatal error below)
2193+
}
2194+
2195+
report_fatal_error("Attempt at an invalid promotion-related conversion");
2196+
}
2197+
21842198
bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
21852199
LLVM_DEBUG(dbgs() << "Promote float operand " << OpNo << ": "; N->dump(&DAG));
21862200
SDValue R = SDValue();
@@ -2416,6 +2430,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
24162430
case ISD::FFREXP: R = PromoteFloatRes_FFREXP(N); break;
24172431

24182432
case ISD::FP_ROUND: R = PromoteFloatRes_FP_ROUND(N); break;
2433+
case ISD::STRICT_FP_ROUND:
2434+
R = PromoteFloatRes_STRICT_FP_ROUND(N);
2435+
break;
24192436
case ISD::LOAD: R = PromoteFloatRes_LOAD(N); break;
24202437
case ISD::SELECT: R = PromoteFloatRes_SELECT(N); break;
24212438
case ISD::SELECT_CC: R = PromoteFloatRes_SELECT_CC(N); break;
@@ -2621,6 +2638,29 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_FP_ROUND(SDNode *N) {
26212638
return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round);
26222639
}
26232640

2641+
// Explicit operation to reduce precision. Reduce the value to half precision
2642+
// and promote it back to the legal type. This is the chained form: operand 0 of N is the incoming chain and operand 1 is the value, and both new nodes produce an MVT::Other chain result.
2643+
SDValue DAGTypeLegalizer::PromoteFloatRes_STRICT_FP_ROUND(SDNode *N) {
2644+
SDLoc DL(N);
2645+
2646+
SDValue Chain = N->getOperand(0);
2647+
SDValue Op = N->getOperand(1);
2648+
EVT VT = N->getValueType(0);
2649+
EVT OpVT = Op->getValueType(0);
2650+
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
2651+
EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
2652+
2653+
// Round promoted float to desired precision. The result is typed as IVT, the integer type with the same bit width as VT, and the node is chained on the incoming Chain.
2654+
SDValue Round = DAG.getNode(GetPromotionOpcodeStrict(OpVT, VT), DL,
2655+
DAG.getVTList(IVT, MVT::Other), Chain, Op);
2656+
// Promote it back to the legal output type. This node is chained on Round's chain result (value 1), so the two conversions stay ordered.
2657+
SDValue Res =
2658+
DAG.getNode(GetPromotionOpcodeStrict(VT, NVT), DL,
2659+
DAG.getVTList(NVT, MVT::Other), Round.getValue(1), Round);
2660+
// Redirect users of N's chain result (value 1) to the chain of the final node.
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
2661+
return Res;
2662+
}
2663+
26242664
SDValue DAGTypeLegalizer::PromoteFloatRes_LOAD(SDNode *N) {
26252665
LoadSDNode *L = cast<LoadSDNode>(N);
26262666
EVT VT = N->getValueType(0);

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
165165
case ISD::FP_TO_FP16:
166166
Res = PromoteIntRes_FP_TO_FP16_BF16(N);
167167
break;
168-
168+
case ISD::STRICT_FP_TO_FP16:
169+
Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
170+
break;
169171
case ISD::GET_ROUNDING: Res = PromoteIntRes_GET_ROUNDING(N); break;
170172

171173
case ISD::AND:
@@ -787,6 +789,16 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
787789
return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
788790
}
789791

792+
// Promote the integer result of a strict FP-to-FP16 style node. The node is
// re-created with the same opcode and the same two operands, but with result
// types (NVT, MVT::Other), where NVT is the type the original result is
// transformed to. Users of the old chain result (value 1) are redirected to
// the new node's chain. The new node is returned as the promoted value.
SDValue DAGTypeLegalizer::PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N) {
793+
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
794+
SDLoc dl(N);
795+
796+
SDValue Res = DAG.getNode(N->getOpcode(), dl, DAG.getVTList(NVT, MVT::Other),
797+
N->getOperand(0), N->getOperand(1));
798+
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
799+
return Res;
800+
}
801+
790802
SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
791803
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
792804
SDLoc dl(N);
@@ -1804,6 +1816,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
18041816
case ISD::FP16_TO_FP:
18051817
case ISD::VP_UINT_TO_FP:
18061818
case ISD::UINT_TO_FP: Res = PromoteIntOp_UINT_TO_FP(N); break;
1819+
case ISD::STRICT_FP16_TO_FP:
18071820
case ISD::STRICT_UINT_TO_FP: Res = PromoteIntOp_STRICT_UINT_TO_FP(N); break;
18081821
case ISD::ZERO_EXTEND: Res = PromoteIntOp_ZERO_EXTEND(N); break;
18091822
case ISD::VP_ZERO_EXTEND: Res = PromoteIntOp_VP_ZERO_EXTEND(N); break;

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
326326
SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
327327
SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
328328
SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
329+
SDValue PromoteIntRes_STRICT_FP_TO_FP16_BF16(SDNode *N);
329330
SDValue PromoteIntRes_XRINT(SDNode *N);
330331
SDValue PromoteIntRes_FREEZE(SDNode *N);
331332
SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
@@ -698,6 +699,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
698699
SDValue PromoteFloatRes_ExpOp(SDNode *N);
699700
SDValue PromoteFloatRes_FFREXP(SDNode *N);
700701
SDValue PromoteFloatRes_FP_ROUND(SDNode *N);
702+
SDValue PromoteFloatRes_STRICT_FP_ROUND(SDNode *N);
701703
SDValue PromoteFloatRes_LOAD(SDNode *N);
702704
SDValue PromoteFloatRes_SELECT(SDNode *N);
703705
SDValue PromoteFloatRes_SELECT_CC(SDNode *N);

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,7 @@ def : Pat <
10971097
multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16_inst_e64> {
10981098
// f16_to_fp patterns
10991099
def : GCNPat <
1100-
(f32 (f16_to_fp i32:$src0)),
1100+
(f32 (any_f16_to_fp i32:$src0)),
11011101
(cvt_f32_f16_inst_e64 SRCMODS.NONE, $src0)
11021102
>;
11031103

@@ -1151,6 +1151,13 @@ multiclass f16_fp_Pats<Instruction cvt_f16_f32_inst_e64, Instruction cvt_f32_f16
11511151
(f16 (uint_to_fp i32:$src)),
11521152
(cvt_f16_f32_inst_e64 SRCMODS.NONE, (V_CVT_F32_U32_e32 VSrc_b32:$src))
11531153
>;
1154+
1155+
// This is only used on targets without half support
1156+
// TODO: Introduce strict variant of AMDGPUfp_to_f16 and share custom lowering
1157+
// Selects strict_fp_to_f16 of an f32 source straight to cvt_f16_f32_inst_e64, forwarding the source modifiers matched by VOP3Mods.
def : GCNPat <
1158+
(i32 (strict_fp_to_f16 (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
1159+
(cvt_f16_f32_inst_e64 $src0_modifiers, f32:$src0)
1160+
>;
11541161
}
11551162

11561163
let SubtargetPredicate = NotHasTrue16BitInsts in

llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
declare float @llvm.experimental.constrained.fpext.f32.f16(half, metadata) #0
55
declare <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half>, metadata) #0
6+
declare half @llvm.experimental.constrained.fptrunc.f16.f32(float, metadata, metadata) #0
7+
declare <2 x half> @llvm.experimental.constrained.fptrunc.v2f16.v2f32(<2 x float>, metadata, metadata) #0
8+
declare float @llvm.fabs.f32(float)
69

710
define float @v_constrained_fpext_f16_to_f32(ptr addrspace(1) %ptr) #0 {
811
; GFX7-LABEL: v_constrained_fpext_f16_to_f32:
@@ -40,4 +43,68 @@ define <2 x float> @v_constrained_fpext_v2f16_to_v2f32(ptr addrspace(1) %ptr) #0
4043
ret <2 x float> %result
4144
}
4245

46+
define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict(float %arg, ptr addrspace(1) %ptr) #0 {
47+
; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict:
48+
; GFX7: ; %bb.0:
49+
; GFX7-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
50+
; GFX7-NEXT: v_cvt_f16_f32_e32 v0, v0
51+
; GFX7-NEXT: v_and_b32_e32 v0, 0xffff, v0
52+
; GFX7-NEXT: v_cvt_f32_f16_e32 v0, v0
53+
; GFX7-NEXT: s_setpc_b64 s[30:31]
54+
%result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
55+
ret void
56+
}
57+
58+
define void @v_constrained_fptrunc_v2f32_to_v2f16_fpexcept_strict(<2 x float> %arg, ptr addrspace(1) %ptr) #0 {
59+
; GFX7-LABEL: v_constrained_fptrunc_v2f32_to_v2f16_fpexcept_strict:
60+
; GFX7: ; %bb.0:
61+
; GFX7-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
62+
; GFX7-NEXT: v_cvt_f16_f32_e32 v1, v1
63+
; GFX7-NEXT: v_cvt_f16_f32_e32 v0, v0
64+
; GFX7-NEXT: s_mov_b32 s6, 0
65+
; GFX7-NEXT: s_mov_b32 s7, 0xf000
66+
; GFX7-NEXT: v_and_b32_e32 v1, 0xffff, v1
67+
; GFX7-NEXT: v_cvt_f32_f16_e32 v1, v1
68+
; GFX7-NEXT: v_and_b32_e32 v0, 0xffff, v0
69+
; GFX7-NEXT: v_cvt_f32_f16_e32 v0, v0
70+
; GFX7-NEXT: s_mov_b32 s4, s6
71+
; GFX7-NEXT: v_cvt_f16_f32_e32 v1, v1
72+
; GFX7-NEXT: s_mov_b32 s5, s6
73+
; GFX7-NEXT: v_cvt_f16_f32_e32 v0, v0
74+
; GFX7-NEXT: v_lshlrev_b32_e32 v1, 16, v1
75+
; GFX7-NEXT: v_or_b32_e32 v0, v0, v1
76+
; GFX7-NEXT: buffer_store_dword v0, v[2:3], s[4:7], 0 addr64
77+
; GFX7-NEXT: s_waitcnt vmcnt(0)
78+
; GFX7-NEXT: s_setpc_b64 s[30:31]
79+
%result = call <2 x half> @llvm.experimental.constrained.fptrunc.v2f16.v2f32(<2 x float> %arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
80+
store <2 x half> %result, ptr addrspace(1) %ptr
81+
ret void
82+
}
83+
84+
define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fneg(float %arg, ptr addrspace(1) %ptr) #0 {
85+
; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fneg:
86+
; GFX7: ; %bb.0:
87+
; GFX7-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
88+
; GFX7-NEXT: v_cvt_f16_f32_e64 v0, -v0
89+
; GFX7-NEXT: v_and_b32_e32 v0, 0xffff, v0
90+
; GFX7-NEXT: v_cvt_f32_f16_e32 v0, v0
91+
; GFX7-NEXT: s_setpc_b64 s[30:31]
92+
%neg.arg = fneg float %arg
93+
%result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %neg.arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
94+
ret void
95+
}
96+
97+
define void @v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fabs(float %arg, ptr addrspace(1) %ptr) #0 {
98+
; GFX7-LABEL: v_constrained_fptrunc_f32_to_f16_fpexcept_strict_fabs:
99+
; GFX7: ; %bb.0:
100+
; GFX7-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
101+
; GFX7-NEXT: v_cvt_f16_f32_e64 v0, |v0|
102+
; GFX7-NEXT: v_and_b32_e32 v0, 0xffff, v0
103+
; GFX7-NEXT: v_cvt_f32_f16_e32 v0, v0
104+
; GFX7-NEXT: s_setpc_b64 s[30:31]
105+
%abs.arg = call float @llvm.fabs.f32(float %arg)
106+
%result = call half @llvm.experimental.constrained.fptrunc.f16.f32(float %abs.arg, metadata !"round.tonearest", metadata !"fpexcept.strict")
107+
ret void
108+
}
109+
43110
attributes #0 = { strictfp }

0 commit comments

Comments
 (0)