Skip to content

Commit 37fcb32

Browse files
authored
[RISCV] Add codegen support for Zvfbfmin (#87911)
This patch adds basic codegen support for the Zvfbfmin extension.
1 parent 3e54768 commit 37fcb32

19 files changed

+1032
-67
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,23 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
10871087
}
10881088
}
10891089

1090+
// TODO: Could we merge some code with zvfhmin?
1091+
if (Subtarget.hasVInstructionsBF16()) {
1092+
for (MVT VT : BF16VecVTs) {
1093+
if (!isTypeLegal(VT))
1094+
continue;
1095+
setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
1096+
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
1097+
setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT,
1098+
Custom);
1099+
setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR,
1100+
ISD::EXTRACT_SUBVECTOR},
1101+
VT, Custom);
1102+
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
1103+
// TODO: Promote to fp32.
1104+
}
1105+
}
1106+
10901107
if (Subtarget.hasVInstructionsF32()) {
10911108
for (MVT VT : F32VecVTs) {
10921109
if (!isTypeLegal(VT))
@@ -1302,6 +1319,19 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
13021319
continue;
13031320
}
13041321

1322+
if (VT.getVectorElementType() == MVT::bf16) {
1323+
setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
1324+
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
1325+
setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT,
1326+
Custom);
1327+
setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR,
1328+
ISD::EXTRACT_SUBVECTOR},
1329+
VT, Custom);
1330+
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
1331+
// TODO: Promote to fp32.
1332+
continue;
1333+
}
1334+
13051335
// We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
13061336
setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT,
13071337
Custom);
@@ -2561,6 +2591,10 @@ static bool useRVVForFixedLengthVectorVT(MVT VT,
25612591
if (!Subtarget.hasVInstructionsF16Minimal())
25622592
return false;
25632593
break;
2594+
case MVT::bf16:
2595+
if (!Subtarget.hasVInstructionsBF16())
2596+
return false;
2597+
break;
25642598
case MVT::f32:
25652599
if (!Subtarget.hasVInstructionsF32())
25662600
return false;
@@ -2612,6 +2646,7 @@ static MVT getContainerForFixedLengthVector(const TargetLowering &TLI, MVT VT,
26122646
case MVT::i16:
26132647
case MVT::i32:
26142648
case MVT::i64:
2649+
case MVT::bf16:
26152650
case MVT::f16:
26162651
case MVT::f32:
26172652
case MVT::f64: {
@@ -8101,8 +8136,10 @@ RISCVTargetLowering::lowerStrictFPExtendOrRoundLike(SDValue Op,
81018136

81028137
// RVV can only widen/truncate fp to types double/half the size as the source.
81038138
if ((VT.getVectorElementType() == MVT::f64 &&
8104-
SrcVT.getVectorElementType() == MVT::f16) ||
8105-
(VT.getVectorElementType() == MVT::f16 &&
8139+
(SrcVT.getVectorElementType() == MVT::f16 ||
8140+
SrcVT.getVectorElementType() == MVT::bf16)) ||
8141+
((VT.getVectorElementType() == MVT::f16 ||
8142+
VT.getVectorElementType() == MVT::bf16) &&
81068143
SrcVT.getVectorElementType() == MVT::f64)) {
81078144
// For double rounding, the intermediate rounding should be round-to-odd.
81088145
unsigned InterConvOpc = Op.getOpcode() == ISD::STRICT_FP_EXTEND
@@ -8146,9 +8183,12 @@ RISCVTargetLowering::lowerVectorFPExtendOrRoundLike(SDValue Op,
81468183
SDValue Src = Op.getOperand(0);
81478184
MVT SrcVT = Src.getSimpleValueType();
81488185

8149-
bool IsDirectExtend = IsExtend && (VT.getVectorElementType() != MVT::f64 ||
8150-
SrcVT.getVectorElementType() != MVT::f16);
8151-
bool IsDirectTrunc = !IsExtend && (VT.getVectorElementType() != MVT::f16 ||
8186+
bool IsDirectExtend =
8187+
IsExtend && (VT.getVectorElementType() != MVT::f64 ||
8188+
(SrcVT.getVectorElementType() != MVT::f16 &&
8189+
SrcVT.getVectorElementType() != MVT::bf16));
8190+
bool IsDirectTrunc = !IsExtend && ((VT.getVectorElementType() != MVT::f16 &&
8191+
VT.getVectorElementType() != MVT::bf16) ||
81528192
SrcVT.getVectorElementType() != MVT::f64);
81538193

81548194
bool IsDirectConv = IsDirectExtend || IsDirectTrunc;

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -355,24 +355,24 @@ defset list<VTypeInfo> AllVectors = {
355355
V_M8, f64, FPR64>;
356356
}
357357
}
358-
}
359358

360-
defset list<VTypeInfo> AllBFloatVectors = {
361-
defset list<VTypeInfo> NoGroupBFloatVectors = {
362-
defset list<VTypeInfo> FractionalGroupBFloatVectors = {
363-
def VBF16MF4: VTypeInfo<vbfloat16mf4_t, vbool64_t, 16, V_MF4, bf16, FPR16>;
364-
def VBF16MF2: VTypeInfo<vbfloat16mf2_t, vbool32_t, 16, V_MF2, bf16, FPR16>;
359+
defset list<VTypeInfo> AllBFloatVectors = {
360+
defset list<VTypeInfo> NoGroupBFloatVectors = {
361+
defset list<VTypeInfo> FractionalGroupBFloatVectors = {
362+
def VBF16MF4: VTypeInfo<vbfloat16mf4_t, vbool64_t, 16, V_MF4, bf16, FPR16>;
363+
def VBF16MF2: VTypeInfo<vbfloat16mf2_t, vbool32_t, 16, V_MF2, bf16, FPR16>;
364+
}
365+
def VBF16M1: VTypeInfo<vbfloat16m1_t, vbool16_t, 16, V_M1, bf16, FPR16>;
366+
}
367+
368+
defset list<GroupVTypeInfo> GroupBFloatVectors = {
369+
def VBF16M2: GroupVTypeInfo<vbfloat16m2_t, vbfloat16m1_t, vbool8_t, 16,
370+
V_M2, bf16, FPR16>;
371+
def VBF16M4: GroupVTypeInfo<vbfloat16m4_t, vbfloat16m1_t, vbool4_t, 16,
372+
V_M4, bf16, FPR16>;
373+
def VBF16M8: GroupVTypeInfo<vbfloat16m8_t, vbfloat16m1_t, vbool2_t, 16,
374+
V_M8, bf16, FPR16>;
365375
}
366-
def VBF16M1: VTypeInfo<vbfloat16m1_t, vbool16_t, 16, V_M1, bf16, FPR16>;
367-
}
368-
369-
defset list<GroupVTypeInfo> GroupBFloatVectors = {
370-
def VBF16M2: GroupVTypeInfo<vbfloat16m2_t, vbfloat16m1_t, vbool8_t, 16,
371-
V_M2, bf16, FPR16>;
372-
def VBF16M4: GroupVTypeInfo<vbfloat16m4_t, vbfloat16m1_t, vbool4_t, 16,
373-
V_M4, bf16, FPR16>;
374-
def VBF16M8: GroupVTypeInfo<vbfloat16m8_t, vbfloat16m1_t, vbool2_t, 16,
375-
V_M8, bf16, FPR16>;
376376
}
377377
}
378378

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,20 @@ foreach fvtiToFWti = AllWidenableFloatVectors in {
14951495
fvti.AVL, fvti.Log2SEW, TA_MA)>;
14961496
}
14971497

1498+
foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in {
1499+
defvar fvti = fvtiToFWti.Vti;
1500+
defvar fwti = fvtiToFWti.Wti;
1501+
let Predicates = [HasVInstructionsBF16] in
1502+
def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
1503+
(!cast<Instruction>("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
1504+
(fvti.Vector (IMPLICIT_DEF)),
1505+
fwti.RegClass:$rs1,
1506+
// Value to indicate no rounding mode change in
1507+
// RISCVInsertReadWriteCSR
1508+
FRM_DYN,
1509+
fvti.AVL, fvti.Log2SEW, TA_MA)>;
1510+
}
1511+
14981512
//===----------------------------------------------------------------------===//
14991513
// Vector Splats
15001514
//===----------------------------------------------------------------------===//

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2670,6 +2670,20 @@ foreach fvtiToFWti = AllWidenableFloatVectors in {
26702670
GPR:$vl, fvti.Log2SEW, TA_MA)>;
26712671
}
26722672

2673+
foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in {
2674+
defvar fvti = fvtiToFWti.Vti;
2675+
defvar fwti = fvtiToFWti.Wti;
2676+
let Predicates = [HasVInstructionsBF16] in
2677+
def : Pat<(fwti.Vector (any_riscv_fpextend_vl
2678+
(fvti.Vector fvti.RegClass:$rs1),
2679+
(fvti.Mask V0),
2680+
VLOpFrag)),
2681+
(!cast<Instruction>("PseudoVFWCVTBF16_F_F_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
2682+
(fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1,
2683+
(fvti.Mask V0),
2684+
GPR:$vl, fvti.Log2SEW, TA_MA)>;
2685+
}
2686+
26732687
// 13.19 Narrowing Floating-Point/Integer Type-Convert Instructions
26742688
defm : VPatNConvertFP2IVL_W_RM<riscv_vfcvt_xu_f_vl, "PseudoVFNCVT_XU_F_W">;
26752689
defm : VPatNConvertFP2IVL_W_RM<riscv_vfcvt_x_f_vl, "PseudoVFNCVT_X_F_W">;
@@ -2714,6 +2728,22 @@ foreach fvtiToFWti = AllWidenableFloatVectors in {
27142728
}
27152729
}
27162730

2731+
foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in {
2732+
defvar fvti = fvtiToFWti.Vti;
2733+
defvar fwti = fvtiToFWti.Wti;
2734+
let Predicates = [HasVInstructionsBF16] in
2735+
def : Pat<(fvti.Vector (any_riscv_fpround_vl
2736+
(fwti.Vector fwti.RegClass:$rs1),
2737+
(fwti.Mask V0), VLOpFrag)),
2738+
(!cast<Instruction>("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK")
2739+
(fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1,
2740+
(fwti.Mask V0),
2741+
// Value to indicate no rounding mode change in
2742+
// RISCVInsertReadWriteCSR
2743+
FRM_DYN,
2744+
GPR:$vl, fvti.Log2SEW, TA_MA)>;
2745+
}
2746+
27172747
// 14. Vector Reduction Operations
27182748

27192749
// 14.1. Vector Single-Width Integer Reduction Instructions

llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2-
; RUN: llc -mtriple riscv32 -mattr=+m,+d,+zfh,+zvfh,+v -verify-machineinstrs < %s | FileCheck %s
3-
; RUN: llc -mtriple riscv64 -mattr=+m,+d,+zfh,+zvfh,+v -verify-machineinstrs < %s | FileCheck %s
2+
; RUN: llc -mtriple riscv32 -mattr=+m,+d,+zfh,+zvfh,+v,+experimental-zvfbfmin -verify-machineinstrs < %s | FileCheck %s
3+
; RUN: llc -mtriple riscv64 -mattr=+m,+d,+zfh,+zvfh,+v,+experimental-zvfbfmin -verify-machineinstrs < %s | FileCheck %s
44

55
define <vscale x 4 x i32> @extract_nxv8i32_nxv4i32_0(<vscale x 8 x i32> %vec) {
66
; CHECK-LABEL: extract_nxv8i32_nxv4i32_0:
@@ -481,6 +481,60 @@ define <vscale x 6 x half> @extract_nxv6f16_nxv12f16_6(<vscale x 12 x half> %in)
481481
ret <vscale x 6 x half> %res
482482
}
483483

484+
define <vscale x 2 x bfloat> @extract_nxv2bf16_nxv16bf16_0(<vscale x 16 x bfloat> %vec) {
485+
; CHECK-LABEL: extract_nxv2bf16_nxv16bf16_0:
486+
; CHECK: # %bb.0:
487+
; CHECK-NEXT: ret
488+
%c = call <vscale x 2 x bfloat> @llvm.vector.extract.nxv2bf16.nxv16bf16(<vscale x 16 x bfloat> %vec, i64 0)
489+
ret <vscale x 2 x bfloat> %c
490+
}
491+
492+
define <vscale x 2 x bfloat> @extract_nxv2bf16_nxv16bf16_2(<vscale x 16 x bfloat> %vec) {
493+
; CHECK-LABEL: extract_nxv2bf16_nxv16bf16_2:
494+
; CHECK: # %bb.0:
495+
; CHECK-NEXT: csrr a0, vlenb
496+
; CHECK-NEXT: srli a0, a0, 2
497+
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
498+
; CHECK-NEXT: vslidedown.vx v8, v8, a0
499+
; CHECK-NEXT: ret
500+
%c = call <vscale x 2 x bfloat> @llvm.vector.extract.nxv2bf16.nxv16bf16(<vscale x 16 x bfloat> %vec, i64 2)
501+
ret <vscale x 2 x bfloat> %c
502+
}
503+
504+
define <vscale x 2 x bfloat> @extract_nxv2bf16_nxv16bf16_4(<vscale x 16 x bfloat> %vec) {
505+
; CHECK-LABEL: extract_nxv2bf16_nxv16bf16_4:
506+
; CHECK: # %bb.0:
507+
; CHECK-NEXT: vmv1r.v v8, v9
508+
; CHECK-NEXT: ret
509+
%c = call <vscale x 2 x bfloat> @llvm.vector.extract.nxv2bf16.nxv16bf16(<vscale x 16 x bfloat> %vec, i64 4)
510+
ret <vscale x 2 x bfloat> %c
511+
}
512+
513+
define <vscale x 6 x bfloat> @extract_nxv6bf16_nxv12bf16_0(<vscale x 12 x bfloat> %in) {
514+
; CHECK-LABEL: extract_nxv6bf16_nxv12bf16_0:
515+
; CHECK: # %bb.0:
516+
; CHECK-NEXT: ret
517+
%res = call <vscale x 6 x bfloat> @llvm.vector.extract.nxv6bf16.nxv12bf16(<vscale x 12 x bfloat> %in, i64 0)
518+
ret <vscale x 6 x bfloat> %res
519+
}
520+
521+
define <vscale x 6 x bfloat> @extract_nxv6bf16_nxv12bf16_6(<vscale x 12 x bfloat> %in) {
522+
; CHECK-LABEL: extract_nxv6bf16_nxv12bf16_6:
523+
; CHECK: # %bb.0:
524+
; CHECK-NEXT: csrr a0, vlenb
525+
; CHECK-NEXT: srli a0, a0, 2
526+
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
527+
; CHECK-NEXT: vslidedown.vx v13, v10, a0
528+
; CHECK-NEXT: vslidedown.vx v12, v9, a0
529+
; CHECK-NEXT: add a1, a0, a0
530+
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
531+
; CHECK-NEXT: vslideup.vx v12, v10, a0
532+
; CHECK-NEXT: vmv2r.v v8, v12
533+
; CHECK-NEXT: ret
534+
%res = call <vscale x 6 x bfloat> @llvm.vector.extract.nxv6bf16.nxv12bf16(<vscale x 12 x bfloat> %in, i64 6)
535+
ret <vscale x 6 x bfloat> %res
536+
}
537+
484538
declare <vscale x 6 x half> @llvm.vector.extract.nxv6f16.nxv12f16(<vscale x 12 x half>, i64)
485539

486540
declare <vscale x 1 x i8> @llvm.vector.extract.nxv1i8.nxv4i8(<vscale x 4 x i8> %vec, i64 %idx)

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fpext-vp.ll

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2-
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v -verify-machineinstrs < %s | FileCheck %s
3-
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v -verify-machineinstrs < %s | FileCheck %s
4-
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v -verify-machineinstrs < %s | FileCheck %s
5-
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v -verify-machineinstrs < %s | FileCheck %s
2+
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfh,+v,+experimental-zvfbfmin -verify-machineinstrs < %s | FileCheck %s
3+
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfh,+v,+experimental-zvfbfmin -verify-machineinstrs < %s | FileCheck %s
4+
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+zvfhmin,+v,+experimental-zvfbfmin -verify-machineinstrs < %s | FileCheck %s
5+
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+zvfhmin,+v,+experimental-zvfbfmin -verify-machineinstrs < %s | FileCheck %s
66

77
declare <2 x float> @llvm.vp.fpext.v2f32.v2f16(<2 x half>, <2 x i1>, i32)
88

@@ -120,3 +120,53 @@ define <32 x double> @vfpext_v32f32_v32f64(<32 x float> %a, <32 x i1> %m, i32 ze
120120
%v = call <32 x double> @llvm.vp.fpext.v32f64.v32f32(<32 x float> %a, <32 x i1> %m, i32 %vl)
121121
ret <32 x double> %v
122122
}
123+
124+
declare <2 x float> @llvm.vp.fpext.v2f32.v2bf16(<2 x bfloat>, <2 x i1>, i32)
125+
126+
define <2 x float> @vfpext_v2bf16_v2f32(<2 x bfloat> %a, <2 x i1> %m, i32 zeroext %vl) {
127+
; CHECK-LABEL: vfpext_v2bf16_v2f32:
128+
; CHECK: # %bb.0:
129+
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, ta, ma
130+
; CHECK-NEXT: vfwcvtbf16.f.f.v v9, v8, v0.t
131+
; CHECK-NEXT: vmv1r.v v8, v9
132+
; CHECK-NEXT: ret
133+
%v = call <2 x float> @llvm.vp.fpext.v2f32.v2bf16(<2 x bfloat> %a, <2 x i1> %m, i32 %vl)
134+
ret <2 x float> %v
135+
}
136+
137+
define <2 x float> @vfpext_v2bf16_v2f32_unmasked(<2 x bfloat> %a, i32 zeroext %vl) {
138+
; CHECK-LABEL: vfpext_v2bf16_v2f32_unmasked:
139+
; CHECK: # %bb.0:
140+
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, ta, ma
141+
; CHECK-NEXT: vfwcvtbf16.f.f.v v9, v8
142+
; CHECK-NEXT: vmv1r.v v8, v9
143+
; CHECK-NEXT: ret
144+
%v = call <2 x float> @llvm.vp.fpext.v2f32.v2bf16(<2 x bfloat> %a, <2 x i1> shufflevector (<2 x i1> insertelement (<2 x i1> undef, i1 true, i32 0), <2 x i1> undef, <2 x i32> zeroinitializer), i32 %vl)
145+
ret <2 x float> %v
146+
}
147+
148+
declare <2 x double> @llvm.vp.fpext.v2f64.v2bf16(<2 x bfloat>, <2 x i1>, i32)
149+
150+
define <2 x double> @vfpext_v2bf16_v2f64(<2 x bfloat> %a, <2 x i1> %m, i32 zeroext %vl) {
151+
; CHECK-LABEL: vfpext_v2bf16_v2f64:
152+
; CHECK: # %bb.0:
153+
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, ta, ma
154+
; CHECK-NEXT: vfwcvtbf16.f.f.v v9, v8, v0.t
155+
; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
156+
; CHECK-NEXT: vfwcvt.f.f.v v8, v9, v0.t
157+
; CHECK-NEXT: ret
158+
%v = call <2 x double> @llvm.vp.fpext.v2f64.v2bf16(<2 x bfloat> %a, <2 x i1> %m, i32 %vl)
159+
ret <2 x double> %v
160+
}
161+
162+
define <2 x double> @vfpext_v2bf16_v2f64_unmasked(<2 x bfloat> %a, i32 zeroext %vl) {
163+
; CHECK-LABEL: vfpext_v2bf16_v2f64_unmasked:
164+
; CHECK: # %bb.0:
165+
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, ta, ma
166+
; CHECK-NEXT: vfwcvtbf16.f.f.v v9, v8
167+
; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
168+
; CHECK-NEXT: vfwcvt.f.f.v v8, v9
169+
; CHECK-NEXT: ret
170+
%v = call <2 x double> @llvm.vp.fpext.v2f64.v2bf16(<2 x bfloat> %a, <2 x i1> shufflevector (<2 x i1> insertelement (<2 x i1> undef, i1 true, i32 0), <2 x i1> undef, <2 x i32> zeroinitializer), i32 %vl)
171+
ret <2 x double> %v
172+
}

0 commit comments

Comments
 (0)