Skip to content

Commit 010eae4

Browse files
committed
[AArch64] Implement intrinsics for SME FP8 FMOPA
1 parent 55dd475 commit 010eae4

File tree

10 files changed

+180
-13
lines changed

10 files changed

+180
-13
lines changed

clang/include/clang/Basic/arm_sme.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,4 +824,14 @@ let SMETargetGuard = "sme-lutv2" in {
824824
def SVLUTI4_ZT_X4 : SInst<"svluti4_zt_{d}_x4", "4i2.u", "cUc", MergeNone, "aarch64_sme_luti4_zt_x4", [IsStreaming, IsInZT0], [ImmCheck<0, ImmCheck0_0>]>;
825825
}
826826

827+
let SMETargetGuard = "sme-f8f32" in {
828+
def SVMOPA_FP8_ZA32 : Inst<"svmopa_za32[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za32",
829+
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_3>]>;
830+
}
831+
832+
let SMETargetGuard = "sme-f8f16" in {
833+
def SVMOPA_FP8_ZA16 : Inst<"svmopa_za16[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za16",
834+
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_1>]>;
835+
}
836+
827837
} // let SVETargetGuard = InvalidMode

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10183,6 +10183,8 @@ CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) {
1018310183
case SVETypeFlags::EltTyInt64:
1018410184
return llvm::ScalableVectorType::get(Builder.getInt64Ty(), 2);
1018510185

10186+
case SVETypeFlags::EltTyMFloat8:
10187+
return llvm::ScalableVectorType::get(Builder.getInt8Ty(), 16);
1018610188
case SVETypeFlags::EltTyFloat16:
1018710189
return llvm::ScalableVectorType::get(Builder.getHalfTy(), 8);
1018810190
case SVETypeFlags::EltTyBFloat16:
@@ -11234,6 +11236,10 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
1123411236
BuiltinID == SME::BI__builtin_sme_svstr_za)
1123511237
return EmitSMELdrStr(TypeFlags, Ops, Builtin->LLVMIntrinsic);
1123611238

11239+
// Emit set FPMR for intrinsics that require it
11240+
if (TypeFlags.setsFPMR())
11241+
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr),
11242+
Ops.pop_back_val());
1123711243
// Handle builtins which require their multi-vector operands to be swapped
1123811244
swapCommutativeSMEOperands(BuiltinID, Ops);
1123911245

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// REQUIRES: aarch64-registered-target
3+
4+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s
5+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
6+
// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s
7+
// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
8+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s
9+
10+
#include <arm_sme.h>
11+
12+
#ifdef SVE_OVERLOADED_FORMS
13+
#define SVE_ACLE_FUNC(A1,A2_UNUSED,A3) A1##A3
14+
#else
15+
#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
16+
#endif
17+
18+
19+
// CHECK-LABEL: define dso_local void @test_svmopa_za16_mf8_m(
20+
// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
21+
// CHECK-NEXT: [[ENTRY:.*:]]
22+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
23+
// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
24+
// CHECK-NEXT: ret void
25+
//
26+
// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za16_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m(
27+
// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
28+
// CPP-CHECK-NEXT: [[ENTRY:.*:]]
29+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
30+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
31+
// CPP-CHECK-NEXT: ret void
32+
//
33+
void test_svmopa_za16_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn,
34+
svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") {
35+
SVE_ACLE_FUNC(svmopa_za16,_mf8,_m_fpm)(1, pn, pm, zn, zm, fpmr);
36+
}
37+
38+
// CHECK-LABEL: define dso_local void @test_svmopa_za32_mf8_m(
39+
// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
40+
// CHECK-NEXT: [[ENTRY:.*:]]
41+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
42+
// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
43+
// CHECK-NEXT: ret void
44+
//
45+
// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za32_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m(
46+
// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
47+
// CPP-CHECK-NEXT: [[ENTRY:.*:]]
48+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
49+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
50+
// CPP-CHECK-NEXT: ret void
51+
//
52+
void test_svmopa_za32_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn,
53+
svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") {
54+
SVE_ACLE_FUNC(svmopa_za32,_mf8,_m_fpm)(3, pn, pm, zn, zm, fpmr);
55+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -target-feature +sme2 -target-feature +sme-f8f16 -target-feature +sme-f8f32 -fsyntax-only -verify %s
2+
3+
// REQUIRES: aarch64-registered-target
4+
5+
#include <arm_sme.h>
6+
7+
void test_svmopa(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm,
8+
fpm_t fpmr) __arm_streaming __arm_inout("za") {
9+
// expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 1]}}
10+
svmopa_za16_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr);
11+
// expected-error@+1 {{argument value 2 is outside the valid range [0, 1]}}
12+
svmopa_za16_mf8_m_fpm(2, pn, pm, zn, zm, fpmr);
13+
14+
// expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 3]}}
15+
svmopa_za32_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr);
16+
// expected-error@+1 {{argument value 4 is outside the valid range [0, 3]}}
17+
svmopa_za32_mf8_m_fpm(4, pn, pm, zn, zm, fpmr);
18+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -verify -emit-llvm-only %s
2+
3+
// REQUIRES: aarch64-registered-target
4+
5+
#include <arm_sme.h>
6+
7+
void test_features(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm,
8+
fpm_t fpmr) __arm_streaming __arm_inout("za") {
9+
// expected-error@+1 {{'svmopa_za16_mf8_m_fpm' needs target feature sme,sme-f8f16}}
10+
svmopa_za16_mf8_m_fpm(0, pn, pm, zn, zm, fpmr);
11+
// expected-error@+1 {{'svmopa_za32_mf8_m_fpm' needs target feature sme,sme-f8f32}}
12+
svmopa_za32_mf8_m_fpm(0, pn, pm, zn, zm, fpmr);
13+
}

clang/utils/TableGen/SveEmitter.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,6 @@ void SVEType::applyTypespec(StringRef TS) {
587587
ElementBitwidth = 16;
588588
break;
589589
case 'm':
590-
Signed = false;
591590
MFloat = true;
592591
Float = false;
593592
BFloat = false;
@@ -702,6 +701,7 @@ void SVEType::applyModifier(char Mod) {
702701
Svcount = false;
703702
Float = false;
704703
BFloat = false;
704+
MFloat = false;
705705
ElementBitwidth = Bitwidth = 64;
706706
NumVectors = 0;
707707
Signed = false;
@@ -712,6 +712,7 @@ void SVEType::applyModifier(char Mod) {
712712
Svcount = false;
713713
Float = false;
714714
BFloat = false;
715+
MFloat = false;
715716
ElementBitwidth = Bitwidth = 32;
716717
NumVectors = 0;
717718
Signed = true;
@@ -723,6 +724,7 @@ void SVEType::applyModifier(char Mod) {
723724
Svcount = false;
724725
Float = false;
725726
BFloat = false;
727+
MFloat = false;
726728
ElementBitwidth = Bitwidth = 32;
727729
NumVectors = 0;
728730
Signed = true;
@@ -735,6 +737,7 @@ void SVEType::applyModifier(char Mod) {
735737
Signed = true;
736738
Float = false;
737739
BFloat = false;
740+
MFloat = false;
738741
ElementBitwidth = Bitwidth = 32;
739742
NumVectors = 0;
740743
break;
@@ -744,6 +747,7 @@ void SVEType::applyModifier(char Mod) {
744747
Signed = true;
745748
Float = false;
746749
BFloat = false;
750+
MFloat = false;
747751
ElementBitwidth = Bitwidth = 64;
748752
NumVectors = 0;
749753
break;
@@ -753,6 +757,7 @@ void SVEType::applyModifier(char Mod) {
753757
Signed = false;
754758
Float = false;
755759
BFloat = false;
760+
MFloat = false;
756761
ElementBitwidth = Bitwidth = 32;
757762
NumVectors = 0;
758763
break;
@@ -765,6 +770,7 @@ void SVEType::applyModifier(char Mod) {
765770
Signed = false;
766771
Float = false;
767772
BFloat = false;
773+
MFloat = false;
768774
ElementBitwidth = Bitwidth = 64;
769775
NumVectors = 0;
770776
break;
@@ -783,25 +789,29 @@ void SVEType::applyModifier(char Mod) {
783789
case 'g':
784790
Signed = false;
785791
Float = false;
792+
MFloat = false;
786793
BFloat = false;
787794
ElementBitwidth = 64;
788795
break;
789796
case '[':
790797
Signed = false;
791798
Float = false;
792799
BFloat = false;
800+
MFloat = false;
793801
ElementBitwidth = 8;
794802
break;
795803
case 't':
796804
Signed = true;
797805
Float = false;
798806
BFloat = false;
807+
MFloat = false;
799808
ElementBitwidth = 32;
800809
break;
801810
case 'z':
802811
Signed = false;
803812
Float = false;
804813
BFloat = false;
814+
MFloat = false;
805815
ElementBitwidth = 32;
806816
break;
807817
case 'O':
@@ -815,6 +825,7 @@ void SVEType::applyModifier(char Mod) {
815825
Svcount = false;
816826
Float = true;
817827
BFloat = false;
828+
MFloat = false;
818829
ElementBitwidth = 32;
819830
break;
820831
case 'N':
@@ -922,6 +933,7 @@ void SVEType::applyModifier(char Mod) {
922933
Predicate = false;
923934
Svcount = false;
924935
Float = false;
936+
MFloat = false;
925937
BFloat = true;
926938
ElementBitwidth = 16;
927939
break;
@@ -932,6 +944,7 @@ void SVEType::applyModifier(char Mod) {
932944
NumVectors = 0;
933945
Float = false;
934946
BFloat = false;
947+
MFloat = false;
935948
break;
936949
case '~':
937950
Float = false;

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2983,6 +2983,13 @@ let TargetPrefix = "aarch64" in {
29832983
LLVMMatchType<0>,
29842984
llvm_anyvector_ty], [ImmArg<ArgIndex<0>>]>;
29852985

2986+
class SME_FP8_OuterProduct_Intrinsic
2987+
: DefaultAttrsIntrinsic<[],
2988+
[llvm_i32_ty,
2989+
llvm_nxv16i1_ty, llvm_nxv16i1_ty,
2990+
llvm_nxv16i8_ty, llvm_nxv16i8_ty],
2991+
[ImmArg<ArgIndex<0>>, IntrInaccessibleMemOnly, IntrHasSideEffects]>;
2992+
29862993
def int_aarch64_sme_mopa : SME_OuterProduct_Intrinsic;
29872994
def int_aarch64_sme_mops : SME_OuterProduct_Intrinsic;
29882995

@@ -2998,6 +3005,10 @@ let TargetPrefix = "aarch64" in {
29983005
def int_aarch64_sme_usmopa_wide : SME_OuterProduct_Intrinsic;
29993006
def int_aarch64_sme_usmops_wide : SME_OuterProduct_Intrinsic;
30003007

3008+
// FP8 outer product
3009+
def int_aarch64_sme_fp8_fmopa_za16 : SME_FP8_OuterProduct_Intrinsic;
3010+
def int_aarch64_sme_fp8_fmopa_za32 : SME_FP8_OuterProduct_Intrinsic;
3011+
30013012
class SME_AddVectorToTile_Intrinsic
30023013
: DefaultAttrsIntrinsic<[],
30033014
[llvm_i32_ty,

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -990,31 +990,30 @@ defm FDOT_VG2_M2ZZI_BtoH : sme2p1_multi_vec_array_vg2_index_f8f16<"fdot", 0b11
990990
defm FDOT_VG4_M4ZZI_BtoH : sme2p1_multi_vec_array_vg4_index_f8f16<"fdot", 0b100, ZZZZ_b_mul_r, ZPR4b8>;
991991
defm FDOT_VG2_M2ZZ_BtoH : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010001, MatrixOp16, ZZ_b, ZPR4b8>;
992992
defm FDOT_VG4_M4ZZ_BtoH : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110001, MatrixOp16, ZZZZ_b, ZPR4b8>;
993-
// TODO: Replace nxv16i8 by nxv16f8
993+
994994
defm FDOT_VG2_M2Z2Z_BtoH : sme2_dot_mla_add_sub_array_vg2_multi<"fdot", 0b0100100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
995995
defm FDOT_VG4_M4Z4Z_BtoH : sme2_dot_mla_add_sub_array_vg4_multi<"fdot", 0b0100100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
996996

997997
def FMLAL_MZZI_BtoH : sme2_mla_ll_array_index_16b<"fmlal", 0b11, 0b00>;
998998
defm FMLAL_VG2_M2ZZI_BtoH : sme2_multi_vec_array_vg2_index_16b<"fmlal", 0b10, 0b111>;
999999
defm FMLAL_VG4_M4ZZI_BtoH : sme2_multi_vec_array_vg4_index_16b<"fmlal", 0b10, 0b110>;
10001000
def FMLAL_VG2_MZZ_BtoH : sme2_mla_long_array_single_16b<"fmlal">;
1001-
// TODO: Replace nxv16i8 by nxv16f8
1001+
10021002
defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, null_frag>;
10031003
defm FMLAL_VG4_M4ZZ_BtoH : sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, null_frag>;
10041004
defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
10051005
defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
10061006

1007-
defm FMOPA_MPPZZ_BtoH : sme2p1_fmop_tile_f8f16<"fmopa", 0b1, 0b0, 0b01>;
1008-
1007+
defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", int_aarch64_sme_fp8_fmopa_za16>;
10091008
} //[HasSMEF8F16]
10101009

10111010
let Predicates = [HasSMEF8F32] in {
1012-
// TODO : Replace nxv16i8 by nxv16f8
1011+
10131012
defm FDOT_VG2_M2ZZI_BtoS : sme2_multi_vec_array_vg2_index_32b<"fdot", 0b01, 0b0111, ZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
10141013
defm FDOT_VG4_M4ZZI_BtoS : sme2_multi_vec_array_vg4_index_32b<"fdot", 0b0001, ZZZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
10151014
defm FDOT_VG2_M2ZZ_BtoS : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010011, MatrixOp32, ZZ_b, ZPR4b8>;
10161015
defm FDOT_VG4_M4ZZ_BtoS : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110011, MatrixOp32, ZZZZ_b, ZPR4b8>;
1017-
// TODO : Replace nxv16i8 by nxv16f8
1016+
10181017
defm FDOT_VG2_M2Z2Z_BtoS : sme2_dot_mla_add_sub_array_vg2_multi<"fdot", 0b0100110, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
10191018
defm FDOT_VG4_M4Z4Z_BtoS : sme2_dot_mla_add_sub_array_vg4_multi<"fdot", 0b0100110, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
10201019

@@ -1024,16 +1023,14 @@ def FVDOTT_VG4_M2ZZI_BtoS : sme2_fp8_multi_vec_array_vg4_index<"fvdott", 0b1>;
10241023
defm FMLALL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"fmlall", 0b01, 0b000, null_frag>;
10251024
defm FMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"fmlall", 0b10, 0b100, null_frag>;
10261025
defm FMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"fmlall", 0b00, 0b1000, null_frag>;
1027-
// TODO: Replace nxv16i8 by nxv16f8
1026+
10281027
defm FMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"fmlall", 0b01000, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, null_frag>;
10291028
defm FMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg24_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8>;
10301029
defm FMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg24_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8>;
10311030
defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
10321031
defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
10331032

1034-
1035-
defm FMOPA_MPPZZ_BtoS : sme_outer_product_fp32<0b0, 0b01, ZPR8, "fmopa", null_frag>;
1036-
1033+
defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", int_aarch64_sme_fp8_fmopa_za32>;
10371034
} //[HasSMEF8F32]
10381035

10391036
let Predicates = [HasSME2, HasSVEBFSCALE] in {

llvm/lib/Target/AArch64/SMEInstrFormats.td

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,21 @@ multiclass sme_outer_product_fp32<bit S, bits<2> sz, ZPRRegOp zpr_ty, string mne
305305
def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_3, nxv4i1, nxv4f32>;
306306
}
307307

308+
multiclass sme2_fp8_fmopa_za32<string mnemonic, SDPatternOperator intrinsic> {
309+
def NAME : sme_fp_outer_product_inst<0, 0b01, 0b00, TileOp32, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
310+
bits<2> ZAda;
311+
let Inst{1-0} = ZAda;
312+
let Inst{2} = 0b0;
313+
314+
let Uses = [FPMR, FPCR];
315+
}
316+
317+
let mayStore = 1, mayLoad = 1 in
318+
def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
319+
320+
def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_3, nxv16i1, nxv16i8>;
321+
}
322+
308323
multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op> {
309324
def NAME : sme_fp_outer_product_inst<S, 0b10, 0b00, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> {
310325
bits<3> ZAda;
@@ -316,12 +331,19 @@ multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op>
316331
def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_7, nxv2i1, nxv2f64>;
317332
}
318333

319-
multiclass sme2p1_fmop_tile_f8f16<string mnemonic, bit bf, bit s, bits<2> op> {
320-
def NAME : sme_fp_outer_product_inst<s, {0,bf}, op, TileOp16, ZPR8, mnemonic> {
334+
multiclass sme2_fp8_fmopa_za16<string mnemonic, SDPatternOperator intrinsic> {
335+
def NAME : sme_fp_outer_product_inst<0, {0, 0b1}, 0b01, TileOp16, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
321336
bits<1> ZAda;
322337
let Inst{2-1} = 0b00;
323338
let Inst{0} = ZAda;
339+
340+
let Uses = [FPMR, FPCR];
324341
}
342+
343+
let mayStore = 1, mayLoad = 1 in
344+
def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileH>, SMEPseudo2Instr<NAME, 0>;
345+
346+
def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_1, nxv16i1, nxv16i8>;
325347
}
326348

327349
multiclass sme2p1_fmop_tile_fp16<string mnemonic, bit bf, bit s, ValueType vt, SDPatternOperator intrinsic = null_frag> {
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme-f8f16,+sme-f8f32 -force-streaming < %s | FileCheck %s
3+
4+
define void @test_fmopa_16(<vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm, <vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm) {
5+
; CHECK-LABEL: test_fmopa_16:
6+
; CHECK: // %bb.0:
7+
; CHECK-NEXT: fmopa za1.h, p0/m, p1/m, z0.b, z1.b
8+
; CHECK-NEXT: ret
9+
call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm,
10+
<vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm)
11+
ret void
12+
}
13+
14+
define void @test_fmopa_32(<vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm, <vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm) #0 {
15+
; CHECK-LABEL: test_fmopa_32:
16+
; CHECK: // %bb.0:
17+
; CHECK-NEXT: fmopa za3.s, p0/m, p1/m, z0.b, z1.b
18+
; CHECK-NEXT: ret
19+
call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm,
20+
<vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm)
21+
ret void
22+
}

0 commit comments

Comments
 (0)