Skip to content

Commit 865952b

Browse files
authored
[NVPTX] Add conversion intrinsics from/to fp8 types (e4m3, e5m2) (#102969)
PTX ISA 8.1 supports FP8 conversions: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt This PR adds the support for: - cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b; - cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a; - cvt.rn.{.relu}.f16x2.f8x2type d, a; where .f8x2type = { .e4m3x2, .e5m2x2 };
1 parent 7f968e3 commit 865952b

File tree

6 files changed

+220
-0
lines changed

6 files changed

+220
-0
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.def

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,21 @@ TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))
584584

585585
TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
586586

587+
TARGET_BUILTIN(__nvvm_ff_to_e4m3x2_rn, "sff", "", AND(SM_89,PTX81))
588+
TARGET_BUILTIN(__nvvm_ff_to_e4m3x2_rn_relu, "sff", "", AND(SM_89,PTX81))
589+
TARGET_BUILTIN(__nvvm_ff_to_e5m2x2_rn, "sff", "", AND(SM_89,PTX81))
590+
TARGET_BUILTIN(__nvvm_ff_to_e5m2x2_rn_relu, "sff", "", AND(SM_89,PTX81))
591+
592+
TARGET_BUILTIN(__nvvm_f16x2_to_e4m3x2_rn, "sV2h", "", AND(SM_89,PTX81))
593+
TARGET_BUILTIN(__nvvm_f16x2_to_e4m3x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
594+
TARGET_BUILTIN(__nvvm_f16x2_to_e5m2x2_rn, "sV2h", "", AND(SM_89,PTX81))
595+
TARGET_BUILTIN(__nvvm_f16x2_to_e5m2x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
596+
597+
TARGET_BUILTIN(__nvvm_e4m3x2_to_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
598+
TARGET_BUILTIN(__nvvm_e4m3x2_to_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
599+
TARGET_BUILTIN(__nvvm_e5m2x2_to_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
600+
TARGET_BUILTIN(__nvvm_e5m2x2_to_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
601+
587602
// Bitcast
588603

589604
BUILTIN(__nvvm_bitcast_f2i, "if", "")

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_86 -target-feature +ptx72 \
2323
// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \
2424
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 -check-prefix=LP64 %s
25+
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_89 -target-feature +ptx81 \
26+
// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \
27+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 %s
2528

2629
#define __device__ __attribute__((device))
2730
#define __global__ __attribute__((global))
@@ -968,6 +971,39 @@ __device__ void nvvm_cvt_sm80() {
968971
// CHECK: ret void
969972
}
970973

974+
// CHECK-LABEL: nvvm_cvt_sm89
975+
__device__ void nvvm_cvt_sm89() {
976+
#if __CUDA_ARCH__ >= 890
977+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float 1.000000e+00, float 1.000000e+00)
978+
__nvvm_ff_to_e4m3x2_rn(1.0f, 1.0f);
979+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
980+
__nvvm_ff_to_e4m3x2_rn_relu(1.0f, 1.0f);
981+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float 1.000000e+00, float 1.000000e+00)
982+
__nvvm_ff_to_e5m2x2_rn(1.0f, 1.0f);
983+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
984+
__nvvm_ff_to_e5m2x2_rn_relu(1.0f, 1.0f);
985+
986+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
987+
__nvvm_f16x2_to_e4m3x2_rn({1.0f16, 1.0f16});
988+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
989+
__nvvm_f16x2_to_e4m3x2_rn_relu({1.0f16, 1.0f16});
990+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
991+
__nvvm_f16x2_to_e5m2x2_rn({1.0f16, 1.0f16});
992+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
993+
__nvvm_f16x2_to_e5m2x2_rn_relu({1.0f16, 1.0f16});
994+
995+
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 18504)
996+
__nvvm_e4m3x2_to_f16x2_rn(0x4848);
997+
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 18504)
998+
__nvvm_e4m3x2_to_f16x2_rn_relu(0x4848);
999+
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 19532)
1000+
__nvvm_e5m2x2_to_f16x2_rn(0x4c4c);
1001+
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 19532)
1002+
__nvvm_e5m2x2_to_f16x2_rn_relu(0x4c4c);
1003+
#endif
1004+
// CHECK: ret void
1005+
}
1006+
9711007
#define NAN32 0x7FBFFFFF
9721008
#define NAN16 (__bf16)0x7FBF
9731009
#define BF16 (__bf16)0.1f

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,33 @@ let TargetPrefix = "nvvm" in {
13121312
def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
13131313
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
13141314

1315+
def int_nvvm_ff_to_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn">,
1316+
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1317+
def int_nvvm_ff_to_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn_relu">,
1318+
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1319+
def int_nvvm_ff_to_e5m2x2_rn : ClangBuiltin<"__nvvm_ff_to_e5m2x2_rn">,
1320+
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1321+
def int_nvvm_ff_to_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_ff_to_e5m2x2_rn_relu">,
1322+
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1323+
1324+
def int_nvvm_f16x2_to_e4m3x2_rn : ClangBuiltin<"__nvvm_f16x2_to_e4m3x2_rn">,
1325+
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
1326+
def int_nvvm_f16x2_to_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_to_e4m3x2_rn_relu">,
1327+
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
1328+
def int_nvvm_f16x2_to_e5m2x2_rn : ClangBuiltin<"__nvvm_f16x2_to_e5m2x2_rn">,
1329+
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
1330+
def int_nvvm_f16x2_to_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_to_e5m2x2_rn_relu">,
1331+
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
1332+
1333+
def int_nvvm_e4m3x2_to_f16x2_rn : ClangBuiltin<"__nvvm_e4m3x2_to_f16x2_rn">,
1334+
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1335+
def int_nvvm_e4m3x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e4m3x2_to_f16x2_rn_relu">,
1336+
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1337+
def int_nvvm_e5m2x2_to_f16x2_rn : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn">,
1338+
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1339+
def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
1340+
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1341+
13151342
//
13161343
// Bitcast
13171344
//

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,35 @@ let hasSideEffects = false in {
723723

724724
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Int32Regs>;
725725
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
726+
727+
// FP8 conversions.
728+
multiclass CVT_TO_F8X2<string F8Name> {
729+
def _f32 :
730+
NVPTXInst<(outs Int16Regs:$dst),
731+
(ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
732+
!strconcat("cvt${mode:base}.satfinite${mode:relu}.",
733+
F8Name, "x2.f32 \t$dst, $src1, $src2;"), []>,
734+
Requires<[hasPTX<81>, hasSM<89>]>;
735+
def _f16x2 :
736+
NVPTXInst<(outs Int16Regs:$dst),
737+
(ins Int32Regs:$src, CvtMode:$mode),
738+
!strconcat("cvt${mode:base}.satfinite${mode:relu}.",
739+
F8Name, "x2.f16x2 \t$dst, $src;"), []>,
740+
Requires<[hasPTX<81>, hasSM<89>]>;
741+
}
742+
743+
defm CVT_e4m3x2 : CVT_TO_F8X2<"e4m3">;
744+
defm CVT_e5m2x2 : CVT_TO_F8X2<"e5m2">;
745+
746+
class CVT_f16x2_fp8<string F8Name> :
747+
NVPTXInst<(outs Int32Regs:$dst),
748+
(ins Int16Regs:$src, CvtMode:$mode),
749+
!strconcat("cvt${mode:base}${mode:relu}.f16x2.",
750+
F8Name, "x2 \t$dst, $src;"), []>,
751+
Requires<[hasPTX<81>, hasSM<89>]>;
752+
753+
def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">;
754+
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
726755
}
727756

728757
//-----------------------------------

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,33 @@ def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a),
15241524
def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
15251525
(CVT_f16_f32 Float32Regs:$a, CvtRN)>;
15261526

1527+
def : Pat<(int_nvvm_ff_to_e4m3x2_rn Float32Regs:$a, Float32Regs:$b),
1528+
(CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
1529+
def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu Float32Regs:$a, Float32Regs:$b),
1530+
(CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
1531+
def : Pat<(int_nvvm_ff_to_e5m2x2_rn Float32Regs:$a, Float32Regs:$b),
1532+
(CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
1533+
def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu Float32Regs:$a, Float32Regs:$b),
1534+
(CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
1535+
1536+
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn Int32Regs:$a),
1537+
(CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN)>;
1538+
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu Int32Regs:$a),
1539+
(CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
1540+
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn Int32Regs:$a),
1541+
(CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN)>;
1542+
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu Int32Regs:$a),
1543+
(CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
1544+
1545+
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn Int16Regs:$a),
1546+
(CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN)>;
1547+
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu Int16Regs:$a),
1548+
(CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN_RELU)>;
1549+
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
1550+
(CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN)>;
1551+
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
1552+
(CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN_RELU)>;
1553+
15271554
//
15281555
// Bitcast
15291556
//
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | FileCheck %s
2+
; RUN: %if ptxas-12.1 %{ llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | %ptxas-verify -arch=sm_89 %}
3+
4+
; CHECK-LABEL: cvt_rn_e4m3x2_f32
5+
define i16 @cvt_rn_e4m3x2_f32(float %f1, float %f2) {
6+
; CHECK: cvt.rn.satfinite.e4m3x2.f32
7+
%val = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %f1, float %f2);
8+
ret i16 %val
9+
}
10+
11+
; CHECK-LABEL: cvt_rn_relu_e4m3x2_f32
12+
define i16 @cvt_rn_relu_e4m3x2_f32(float %f1, float %f2) {
13+
; CHECK: cvt.rn.satfinite.relu.e4m3x2.f32
14+
%val = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %f1, float %f2);
15+
ret i16 %val
16+
}
17+
18+
; CHECK-LABEL: cvt_rn_e5m2x2_f32
19+
define i16 @cvt_rn_e5m2x2_f32(float %f1, float %f2) {
20+
; CHECK: cvt.rn.satfinite.e5m2x2.f32
21+
%val = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %f1, float %f2);
22+
ret i16 %val
23+
}
24+
25+
; CHECK-LABEL: cvt_rn_relu_e5m2x2_f32
26+
define i16 @cvt_rn_relu_e5m2x2_f32(float %f1, float %f2) {
27+
; CHECK: cvt.rn.satfinite.relu.e5m2x2.f32
28+
%val = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %f1, float %f2);
29+
ret i16 %val
30+
}
31+
32+
; CHECK-LABEL: cvt_rn_e4m3x2_f16x2
33+
define i16 @cvt_rn_e4m3x2_f16x2(<2 x half> %in) {
34+
; CHECK: cvt.rn.satfinite.e4m3x2.f16x2
35+
%val = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %in);
36+
ret i16 %val
37+
}
38+
39+
; CHECK-LABEL: cvt_rn_relu_e4m3x2_f16x2
40+
define i16 @cvt_rn_relu_e4m3x2_f16x2(<2 x half> %in) {
41+
; CHECK: cvt.rn.satfinite.relu.e4m3x2.f16x2
42+
%val = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %in);
43+
ret i16 %val
44+
}
45+
46+
; CHECK-LABEL: cvt_rn_e5m2x2_f16x2
47+
define i16 @cvt_rn_e5m2x2_f16x2(<2 x half> %in) {
48+
; CHECK: cvt.rn.satfinite.e5m2x2.f16x2
49+
%val = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %in);
50+
ret i16 %val
51+
}
52+
53+
; CHECK-LABEL: cvt_rn_relu_e5m2x2_f16x2
54+
define i16 @cvt_rn_relu_e5m2x2_f16x2(<2 x half> %in) {
55+
; CHECK: cvt.rn.satfinite.relu.e5m2x2.f16x2
56+
%val = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %in);
57+
ret i16 %val
58+
}
59+
60+
; CHECK-LABEL: cvt_rn_f16x2_e4m3x2
61+
define <2 x half> @cvt_rn_f16x2_e4m3x2(i16 %in) {
62+
; CHECK: cvt.rn.f16x2.e4m3x2
63+
%val = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %in);
64+
ret <2 x half> %val
65+
}
66+
67+
; CHECK-LABEL: cvt_rn_relu_f16x2_e4m3x2
68+
define <2 x half> @cvt_rn_relu_f16x2_e4m3x2(i16 %in) {
69+
; CHECK: cvt.rn.relu.f16x2.e4m3x2
70+
%val = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %in);
71+
ret <2 x half> %val
72+
}
73+
74+
; CHECK-LABEL: cvt_rn_f16x2_e5m2x2
75+
define <2 x half> @cvt_rn_f16x2_e5m2x2(i16 %in) {
76+
; CHECK: cvt.rn.f16x2.e5m2x2
77+
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %in);
78+
ret <2 x half> %val
79+
}
80+
81+
; CHECK-LABEL: cvt_rn_relu_f16x2_e5m2x2
82+
define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
83+
; CHECK: cvt.rn.relu.f16x2.e5m2x2
84+
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
85+
ret <2 x half> %val
86+
}

0 commit comments

Comments
 (0)