Skip to content

Commit 42ea605

Browse files
[fixup] Fix SIMD register counting
* Correctly count SIMD vectors. * Handle the new mfloat8x8_t and mfloat8x16_t types; update tests accordingly.
1 parent dfef87f commit 42ea605

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

clang/lib/CodeGen/Targets/AArch64.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -368,18 +368,34 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
368368
if (EIT->getNumBits() > 128)
369369
return getNaturalAlignIndirect(Ty, false);
370370

371-
if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
372-
if (BT->isSVEBool() || BT->isSVECount())
373-
NPRN = std::min(NPRN + 1, 4u);
374-
else if (BT->getKind() == BuiltinType::SveBoolx2)
375-
NPRN = std::min(NPRN + 2, 4u);
376-
else if (BT->getKind() == BuiltinType::SveBoolx4)
377-
NPRN = std::min(NPRN + 4, 4u);
378-
else if (BT->isFloatingPoint() || BT->isVectorType())
371+
if (Ty->isVectorType())
372+
NSRN = std::min(NSRN + 1, 8u);
373+
else if (const auto *BT = Ty->getAs<BuiltinType>()) {
374+
if (BT->isFloatingPoint())
379375
NSRN = std::min(NSRN + 1, 8u);
380-
else if (BT->isSVESizelessBuiltinType())
381-
NSRN = std::min(
382-
NSRN + getContext().getBuiltinVectorTypeInfo(BT).NumVectors, 8u);
376+
else {
377+
switch (BT->getKind()) {
378+
case BuiltinType::MFloat8x8:
379+
case BuiltinType::MFloat8x16:
380+
NSRN = std::min(NSRN + 1, 8u);
381+
break;
382+
case BuiltinType::SveBool:
383+
case BuiltinType::SveCount:
384+
NPRN = std::min(NPRN + 1, 4u);
385+
break;
386+
case BuiltinType::SveBoolx2:
387+
NPRN = std::min(NPRN + 2, 4u);
388+
break;
389+
case BuiltinType::SveBoolx4:
390+
NPRN = std::min(NPRN + 4, 4u);
391+
break;
392+
default:
393+
if (BT->isSVESizelessBuiltinType())
394+
NSRN = std::min(
395+
NSRN + getContext().getBuiltinVectorTypeInfo(BT).NumVectors,
396+
8u);
397+
}
398+
}
383399
}
384400

385401
return (isPromotableIntegerTypeForABI(Ty) && isDarwinPCS()
@@ -615,7 +631,8 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
615631
// but with the difference that any floating-point type is allowed,
616632
// including __fp16.
617633
if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
618-
if (BT->isFloatingPoint())
634+
if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
635+
BT->getKind() == BuiltinType::MFloat8x8)
619636
return true;
620637
} else if (const VectorType *VT = Ty->getAs<VectorType>()) {
621638
if (auto Kind = VT->getVectorKind();

clang/test/CodeGen/aarch64-pure-scalable-args.c

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
// REQUIRES: aarch64-registered-target
66

7+
#include <arm_neon.h>
78
#include <arm_sve.h>
89
#include <stdarg.h>
910

@@ -16,6 +17,10 @@ typedef struct {
1617
float f[4];
1718
} HFA;
1819

20+
typedef struct {
21+
mfloat8x16_t f[4];
22+
} HVA;
23+
1924
// Pure Scalable Type, needs 4 Z-regs, 2 P-regs
2025
typedef struct {
2126
bvec a;
@@ -140,17 +145,19 @@ void test_argpass_last_p(PST *p) {
140145
// Not enough Z-regs, push PST to memory and pass a pointer, Z-regs and
141146
// P-regs still available for other arguments
142147
// u -> z0
143-
// 0.0 -> d1-d4
148+
// v -> q1
149+
// w -> q2
150+
// 0.0 -> d3-d4
144151
// 1 -> w0
145152
// *p -> memory, address -> x1
146153
// 2 -> w2
147154
// 3.0 -> d5
148155
// true -> p0
149-
void test_argpass_no_z(PST *p, double dummy, svmfloat8_t u) {
150-
void argpass_no_z_callee(svmfloat8_t, double, double, double, double, int, PST, int, double, svbool_t);
151-
argpass_no_z_callee(u, .0, .0, .0, .0, 1, *p, 2, 3.0, svptrue_b64());
156+
void test_argpass_no_z(PST *p, double dummy, svmfloat8_t u, int8x16_t v, mfloat8x16_t w) {
157+
void argpass_no_z_callee(svmfloat8_t, int8x16_t, mfloat8x16_t, double, double, int, PST, int, double, svbool_t);
158+
argpass_no_z_callee(u, v, w, .0, .0, 1, *p, 2, 3.0, svptrue_b64());
152159
}
153-
// CHECK: declare void @argpass_no_z_callee(<vscale x 16 x i8>, double noundef, double noundef, double noundef, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
160+
// CHECK: declare void @argpass_no_z_callee(<vscale x 16 x i8>, <16 x i8> noundef, <16 x i8>, double noundef, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
154161

155162

156163
// Like the above, using a tuple to occupy some registers.
@@ -200,6 +207,20 @@ void test_argpass_no_z_hfa(HFA *h, PST *p) {
200207
// CHECK-AAPCS: declare void @argpass_no_z_hfa_callee(double noundef, [4 x float] alignstack(8), i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
201208
// CHECK-DARWIN: declare void @argpass_no_z_hfa_callee(double noundef, [4 x float], i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
202209

210+
// Not enough Z-regs (consumed by a HVA), PST passed indirectly
211+
// 0.0 -> d0
212+
// *h -> q1-q4
213+
// 1 -> w0
214+
// *p -> memory, address -> x1
215+
// p -> x1
216+
// 2 -> w2
217+
// true -> p0
218+
void test_argpass_no_z_hva(HVA *h, PST *p) {
219+
void argpass_no_z_hva_callee(double, HVA, int, PST, int, svbool_t);
220+
argpass_no_z_hva_callee(.0, *h, 1, *p, 2, svptrue_b64());
221+
}
222+
// CHECK-AAPCS: declare void @argpass_no_z_hva_callee(double noundef, [4 x <16 x i8>] alignstack(16), i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
223+
// CHECK-DARWIN: declare void @argpass_no_z_hva_callee(double noundef, [4 x <16 x i8>], i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
203224

204225
// Not enough P-regs, PST passed indirectly, Z-regs and P-regs still available.
205226
// true -> p0-p2

0 commit comments

Comments
 (0)