Skip to content

[AArch64][SVE] Fix definition of bfloat fcvt intrinsics. #110281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions clang/include/clang/Basic/arm_sve.td
Original file line number Diff line number Diff line change
Expand Up @@ -943,11 +943,6 @@ defm SVFCVTZS_S64_F16 : SInstCvtMXZ<"svcvt_s64[_f16]", "ddPO", "dPO", "l", "aar
defm SVFCVTZS_S32_F32 : SInstCvtMXZ<"svcvt_s32[_f32]", "ddPM", "dPM", "i", "aarch64_sve_fcvtzs", [IsOverloadCvt]>;
defm SVFCVTZS_S64_F32 : SInstCvtMXZ<"svcvt_s64[_f32]", "ddPM", "dPM", "l", "aarch64_sve_fcvtzs_i64f32">;

let SVETargetGuard = "sve,bf16", SMETargetGuard = "sme,bf16" in {
defm SVCVT_BF16_F32 : SInstCvtMXZ<"svcvt_bf16[_f32]", "ddPM", "dPM", "b", "aarch64_sve_fcvt_bf16f32">;
def SVCVTNT_BF16_F32 : SInst<"svcvtnt_bf16[_f32]", "ddPM", "b", MergeOp1, "aarch64_sve_fcvtnt_bf16f32", [IsOverloadNone, VerifyRuntimeMode]>;
}

// svcvt_s##_f64
defm SVFCVTZS_S32_F64 : SInstCvtMXZ<"svcvt_s32[_f64]", "ttPd", "tPd", "d", "aarch64_sve_fcvtzs_i32f64">;
defm SVFCVTZS_S64_F64 : SInstCvtMXZ<"svcvt_s64[_f64]", "ddPN", "dPN", "l", "aarch64_sve_fcvtzs", [IsOverloadCvt]>;
Expand Down Expand Up @@ -1003,19 +998,26 @@ defm SVFCVT_F32_F64 : SInstCvtMXZ<"svcvt_f32[_f64]", "MMPd", "MPd", "d", "aarc
defm SVFCVT_F64_F16 : SInstCvtMXZ<"svcvt_f64[_f16]", "ddPO", "dPO", "d", "aarch64_sve_fcvt_f64f16">;
defm SVFCVT_F64_F32 : SInstCvtMXZ<"svcvt_f64[_f32]", "ddPM", "dPM", "d", "aarch64_sve_fcvt_f64f32">;

let SVETargetGuard = "sve2", SMETargetGuard = "sme" in {
defm SVCVTLT_F32 : SInstCvtMX<"svcvtlt_f32[_f16]", "ddPh", "dPh", "f", "aarch64_sve_fcvtlt_f32f16">;
defm SVCVTLT_F64 : SInstCvtMX<"svcvtlt_f64[_f32]", "ddPh", "dPh", "d", "aarch64_sve_fcvtlt_f64f32">;
let SVETargetGuard = "sve,bf16", SMETargetGuard = "sme,bf16" in {
defm SVCVT_BF16_F32 : SInstCvtMXZ<"svcvt_bf16[_f32]", "$$Pd", "$Pd", "f", "aarch64_sve_fcvt_bf16f32_v2">;

def SVCVTNT_BF16_F32 : SInst<"svcvtnt_bf16[_f32]", "$$Pd", "f", MergeOp1, "aarch64_sve_fcvtnt_bf16f32_v2", [IsOverloadNone, VerifyRuntimeMode]>;
// SVCVTNT_X_BF16_F32 : Implemented as macro by SveEmitter.cpp
}

defm SVCVTX_F32 : SInstCvtMXZ<"svcvtx_f32[_f64]", "MMPd", "MPd", "d", "aarch64_sve_fcvtx_f32f64">;
let SVETargetGuard = "sve2", SMETargetGuard = "sme" in {
defm SVCVTLT_F32_F16 : SInstCvtMX<"svcvtlt_f32[_f16]", "ddPh", "dPh", "f", "aarch64_sve_fcvtlt_f32f16">;
defm SVCVTLT_F64_F32 : SInstCvtMX<"svcvtlt_f64[_f32]", "ddPh", "dPh", "d", "aarch64_sve_fcvtlt_f64f32">;

def SVCVTNT_F32 : SInst<"svcvtnt_f16[_f32]", "hhPd", "f", MergeOp1, "aarch64_sve_fcvtnt_f16f32", [IsOverloadNone, VerifyRuntimeMode]>;
def SVCVTNT_F64 : SInst<"svcvtnt_f32[_f64]", "hhPd", "d", MergeOp1, "aarch64_sve_fcvtnt_f32f64", [IsOverloadNone, VerifyRuntimeMode]>;
// SVCVTNT_X : Implemented as macro by SveEmitter.cpp
defm SVCVTX_F32_F64 : SInstCvtMXZ<"svcvtx_f32[_f64]", "MMPd", "MPd", "d", "aarch64_sve_fcvtx_f32f64">;

def SVCVTXNT_F32 : SInst<"svcvtxnt_f32[_f64]", "MMPd", "d", MergeOp1, "aarch64_sve_fcvtxnt_f32f64", [IsOverloadNone, VerifyRuntimeMode]>;
// SVCVTXNT_X_F32 : Implemented as macro by SveEmitter.cpp
def SVCVTNT_F16_F32 : SInst<"svcvtnt_f16[_f32]", "hhPd", "f", MergeOp1, "aarch64_sve_fcvtnt_f16f32", [IsOverloadNone, VerifyRuntimeMode]>;
def SVCVTNT_F32_F64 : SInst<"svcvtnt_f32[_f64]", "hhPd", "d", MergeOp1, "aarch64_sve_fcvtnt_f32f64", [IsOverloadNone, VerifyRuntimeMode]>;
// SVCVTNT_X_F16_F32 : Implemented as macro by SveEmitter.cpp
// SVCVTNT_X_F32_F64 : Implemented as macro by SveEmitter.cpp

def SVCVTXNT_F32_F64 : SInst<"svcvtxnt_f32[_f64]", "MMPd", "d", MergeOp1, "aarch64_sve_fcvtxnt_f32f64", [IsOverloadNone, VerifyRuntimeMode]>;
// SVCVTXNT_X_F32_F64 : Implemented as macro by SveEmitter.cpp
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
24 changes: 12 additions & 12 deletions clang/test/CodeGen/aarch64-sve-intrinsics/acle_sve_cvt-bfloat.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@

// CHECK-LABEL: @test_svcvt_bf16_f32_x(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32(<vscale x 8 x bfloat> undef, <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32.v2(<vscale x 8 x bfloat> undef, <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
// CPP-CHECK-LABEL: @_Z21test_svcvt_bf16_f32_xu10__SVBool_tu13__SVFloat32_t(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32(<vscale x 8 x bfloat> undef, <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32.v2(<vscale x 8 x bfloat> undef, <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
svbfloat16_t test_svcvt_bf16_f32_x(svbool_t pg, svfloat32_t op) MODE_ATTR {
Expand All @@ -40,14 +40,14 @@ svbfloat16_t test_svcvt_bf16_f32_x(svbool_t pg, svfloat32_t op) MODE_ATTR {

// CHECK-LABEL: @test_svcvt_bf16_f32_z(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32(<vscale x 8 x bfloat> zeroinitializer, <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32.v2(<vscale x 8 x bfloat> zeroinitializer, <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
// CPP-CHECK-LABEL: @_Z21test_svcvt_bf16_f32_zu10__SVBool_tu13__SVFloat32_t(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32(<vscale x 8 x bfloat> zeroinitializer, <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32.v2(<vscale x 8 x bfloat> zeroinitializer, <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
svbfloat16_t test_svcvt_bf16_f32_z(svbool_t pg, svfloat32_t op) MODE_ATTR {
Expand All @@ -56,14 +56,14 @@ svbfloat16_t test_svcvt_bf16_f32_z(svbool_t pg, svfloat32_t op) MODE_ATTR {

// CHECK-LABEL: @test_svcvt_bf16_f32_m(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32(<vscale x 8 x bfloat> [[INACTIVE:%.*]], <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32.v2(<vscale x 8 x bfloat> [[INACTIVE:%.*]], <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
// CPP-CHECK-LABEL: @_Z21test_svcvt_bf16_f32_mu14__SVBfloat16_tu10__SVBool_tu13__SVFloat32_t(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32(<vscale x 8 x bfloat> [[INACTIVE:%.*]], <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvt.bf16f32.v2(<vscale x 8 x bfloat> [[INACTIVE:%.*]], <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
svbfloat16_t test_svcvt_bf16_f32_m(svbfloat16_t inactive, svbool_t pg, svfloat32_t op) MODE_ATTR {
Expand Down
16 changes: 8 additions & 8 deletions clang/test/CodeGen/aarch64-sve-intrinsics/acle_sve_cvtnt.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@

// CHECK-LABEL: @test_svcvtnt_bf16_f32_x(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32(<vscale x 8 x bfloat> [[EVEN:%.*]], <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32.v2(<vscale x 8 x bfloat> [[EVEN:%.*]], <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
// CPP-CHECK-LABEL: @_Z23test_svcvtnt_bf16_f32_xu14__SVBfloat16_tu10__SVBool_tu13__SVFloat32_t(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32(<vscale x 8 x bfloat> [[EVEN:%.*]], <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32.v2(<vscale x 8 x bfloat> [[EVEN:%.*]], <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
svbfloat16_t test_svcvtnt_bf16_f32_x(svbfloat16_t even, svbool_t pg, svfloat32_t op) MODE_ATTR {
Expand All @@ -40,14 +40,14 @@ svbfloat16_t test_svcvtnt_bf16_f32_x(svbfloat16_t even, svbool_t pg, svfloat32_t

// CHECK-LABEL: @test_svcvtnt_bf16_f32_m(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32(<vscale x 8 x bfloat> [[EVEN:%.*]], <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32.v2(<vscale x 8 x bfloat> [[EVEN:%.*]], <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
// CPP-CHECK-LABEL: @_Z23test_svcvtnt_bf16_f32_mu14__SVBfloat16_tu10__SVBool_tu13__SVFloat32_t(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32(<vscale x 8 x bfloat> [[EVEN:%.*]], <vscale x 8 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[PG:%.*]])
// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.aarch64.sve.fcvtnt.bf16f32.v2(<vscale x 8 x bfloat> [[EVEN:%.*]], <vscale x 4 x i1> [[TMP0]], <vscale x 4 x float> [[OP:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 8 x bfloat> [[TMP1]]
//
svbfloat16_t test_svcvtnt_bf16_f32_m(svbfloat16_t even, svbool_t pg, svfloat32_t op) MODE_ATTR {
Expand Down
4 changes: 2 additions & 2 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -2184,8 +2184,8 @@ def int_aarch64_sve_fcvtzs_i32f64 : Builtin_SVCVT<llvm_nxv4i32_ty, llvm_nxv2i1
def int_aarch64_sve_fcvtzs_i64f16 : Builtin_SVCVT<llvm_nxv2i64_ty, llvm_nxv2i1_ty, llvm_nxv8f16_ty>;
def int_aarch64_sve_fcvtzs_i64f32 : Builtin_SVCVT<llvm_nxv2i64_ty, llvm_nxv2i1_ty, llvm_nxv4f32_ty>;

def int_aarch64_sve_fcvt_bf16f32 : Builtin_SVCVT<llvm_nxv8bf16_ty, llvm_nxv8i1_ty, llvm_nxv4f32_ty>;
def int_aarch64_sve_fcvtnt_bf16f32 : Builtin_SVCVT<llvm_nxv8bf16_ty, llvm_nxv8i1_ty, llvm_nxv4f32_ty>;
def int_aarch64_sve_fcvt_bf16f32_v2 : Builtin_SVCVT<llvm_nxv8bf16_ty, llvm_nxv4i1_ty, llvm_nxv4f32_ty>;
def int_aarch64_sve_fcvtnt_bf16f32_v2 : Builtin_SVCVT<llvm_nxv8bf16_ty, llvm_nxv4i1_ty, llvm_nxv4f32_ty>;

def int_aarch64_sve_fcvtzu_i32f16 : Builtin_SVCVT<llvm_nxv4i32_ty, llvm_nxv4i1_ty, llvm_nxv8f16_ty>;
def int_aarch64_sve_fcvtzu_i32f64 : Builtin_SVCVT<llvm_nxv4i32_ty, llvm_nxv2i1_ty, llvm_nxv2f64_ty>;
Expand Down
38 changes: 38 additions & 0 deletions llvm/lib/IR/AutoUpgrade.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,12 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
return false; // No other 'aarch64.sve.bf*'.
}

// 'aarch64.sve.fcvt.bf16f32' || 'aarch64.sve.fcvtnt.bf16f32'
if (Name == "fcvt.bf16f32" || Name == "fcvtnt.bf16f32") {
NewFn = nullptr;
return true;
}

if (Name.consume_front("addqv")) {
// 'aarch64.sve.addqv'.
if (!F->getReturnType()->isFPOrFPVectorTy())
Expand Down Expand Up @@ -4072,6 +4078,35 @@ static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F,
return Rep;
}

static Value *upgradeAArch64IntrinsicCall(StringRef Name, CallBase *CI,
Function *F, IRBuilder<> &Builder) {
Intrinsic::ID NewID =
StringSwitch<Intrinsic::ID>(Name)
.Case("sve.fcvt.bf16f32", Intrinsic::aarch64_sve_fcvt_bf16f32_v2)
.Case("sve.fcvtnt.bf16f32", Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2)
.Default(Intrinsic::not_intrinsic);
if (NewID == Intrinsic::not_intrinsic)
llvm_unreachable("Unhandled Intrinsic!");

SmallVector<Value *, 3> Args(CI->args());

// The original intrinsics incorrectly used a predicate based on the smallest
// element type rather than the largest.
Type *BadPredTy = ScalableVectorType::get(Builder.getInt1Ty(), 8);
Type *GoodPredTy = ScalableVectorType::get(Builder.getInt1Ty(), 4);

if (Args[1]->getType() != BadPredTy)
llvm_unreachable("Unexpected predicate type!");

Args[1] = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_to_svbool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given we won't get any IR verification of the old intrinsic, is it worth asserting or reporting an error if the predicate isn't <vscale x 8 x i1> here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

BadPredTy, Args[1]);
Args[1] = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
GoodPredTy, Args[1]);

Function *NewF = Intrinsic::getDeclaration(CI->getModule(), NewID);
return Builder.CreateCall(NewF, Args, CI->getName());
}

static Value *upgradeARMIntrinsicCall(StringRef Name, CallBase *CI, Function *F,
IRBuilder<> &Builder) {
if (Name == "mve.vctp64.old") {
Expand Down Expand Up @@ -4325,6 +4360,7 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {

bool IsX86 = Name.consume_front("x86.");
bool IsNVVM = Name.consume_front("nvvm.");
bool IsAArch64 = Name.consume_front("aarch64.");
bool IsARM = Name.consume_front("arm.");
bool IsAMDGCN = Name.consume_front("amdgcn.");
bool IsDbg = Name.consume_front("dbg.");
Expand All @@ -4336,6 +4372,8 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder);
} else if (IsX86) {
Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder);
} else if (IsAArch64) {
Rep = upgradeAArch64IntrinsicCall(Name, CI, F, Builder);
} else if (IsARM) {
Rep = upgradeARMIntrinsicCall(Name, CI, F, Builder);
} else if (IsAMDGCN) {
Expand Down
9 changes: 8 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5549,10 +5549,17 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
"Only expect to cast between legal scalable predicate types!");

// Return the operand if the cast isn't changing type,
// e.g. <n x 16 x i1> -> <n x 16 x i1>
if (InVT == VT)
return Op;

// Look through casts to <vscale x 16 x i1> when their input has more lanes
// than VT. This will increase the chances of removing casts that introduce
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps better written as casts to <vscale x 16 x i1> in case it's confusing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I knew something was odd when I was writing it but couldn't figure it out :) I'll change it to MVT::nxv16i1 if that's ok?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, plus removed another redundant reference to <n x 16 x i1>.

// new lanes, which have to be explicitly zero'd.
if (Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
Op.getConstantOperandVal(0) == Intrinsic::aarch64_sve_convert_to_svbool &&
Op.getOperand(1).getValueType().bitsGT(VT))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something that should have been caught by the SVEIntrinsicOpts pass?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. At the IR level the conversion is a two step process because you have to cast through the "svbool" type. i.e. There is no intrinsic that will convert <vscale x 8 x i1> to <vscale x 4 x i1>.

Op = Op.getOperand(1);

SDValue Reinterpret = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);

// We only have to zero the lanes if new lanes are being defined, e.g. when
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -2425,8 +2425,8 @@ let Predicates = [HasBF16, HasSVEorSME] in {
defm BFMLALT_ZZZ : sve2_fp_mla_long<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt>;
defm BFMLALB_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b100, "bfmlalb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalb_lane_v2>;
defm BFMLALT_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt_lane_v2>;
defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32, AArch64fcvtr_mt>;
defm BFCVTNT_ZPmZ : sve_bfloat_convert<0b0, "bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32>;
defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32_v2, AArch64fcvtr_mt>;
defm BFCVTNT_ZPmZ : sve_bfloat_convert<0b0, "bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32_v2>;
} // End HasBF16, HasSVEorSME

let Predicates = [HasSVEorSME] in {
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2157,7 +2157,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
switch (IID) {
default:
break;
case Intrinsic::aarch64_sve_fcvt_bf16f32:
case Intrinsic::aarch64_sve_fcvt_bf16f32_v2:
case Intrinsic::aarch64_sve_fcvt_f16f32:
case Intrinsic::aarch64_sve_fcvt_f16f64:
case Intrinsic::aarch64_sve_fcvt_f32f16:
Expand Down Expand Up @@ -2188,7 +2188,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
case Intrinsic::aarch64_sve_ucvtf_f32i64:
case Intrinsic::aarch64_sve_ucvtf_f64i32:
return instCombineSVEAllOrNoActiveUnary(IC, II);
case Intrinsic::aarch64_sve_fcvtnt_bf16f32:
case Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2:
case Intrinsic::aarch64_sve_fcvtnt_f16f32:
case Intrinsic::aarch64_sve_fcvtnt_f32f64:
case Intrinsic::aarch64_sve_fcvtxnt_f32f64:
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/SVEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -8811,7 +8811,7 @@ multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op,
SDPatternOperator ir_op = null_frag> {
def NAME : sve_bfloat_convert<N, asm>;

def : SVE_3_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8i1, nxv4f32, !cast<Instruction>(NAME)>;
def : SVE_3_Op_Pat<nxv8bf16, op, nxv8bf16, nxv4i1, nxv4f32, !cast<Instruction>(NAME)>;
def : SVE_1_Op_Passthru_Round_Pat<nxv4bf16, ir_op, nxv4i1, nxv4f32, !cast<Instruction>(NAME)>;
def : SVE_1_Op_Passthru_Round_Pat<nxv2bf16, ir_op, nxv2i1, nxv2f32, !cast<Instruction>(NAME)>;
}
Expand Down
Loading
Loading