Skip to content

Commit 83331bb

Browse files
Fix SVE tuples
1 parent 7e2d603 commit 83331bb

File tree

3 files changed

+111
-82
lines changed

3 files changed

+111
-82
lines changed

clang/lib/CodeGen/Targets/AArch64.cpp

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class AArch64ABIInfo : public ABIInfo {
5252

5353
bool isIllegalVectorType(QualType Ty) const;
5454

55+
bool passAsAggregateType(QualType Ty) const;
5556
bool passAsPureScalableType(QualType Ty, unsigned &NV, unsigned &NP,
5657
SmallVectorImpl<llvm::Type *> &CoerceToSeq) const;
5758

@@ -337,6 +338,10 @@ ABIArgInfo AArch64ABIInfo::coerceAndExpandPureScalableAggregate(
337338
NSRN += NVec;
338339
NPRN += NPred;
339340

341+
// Handle SVE vector tuples.
342+
if (Ty->isSVESizelessBuiltinType())
343+
return ABIArgInfo::getDirect();
344+
340345
llvm::Type *UnpaddedCoerceToType =
341346
UnpaddedCoerceToSeq.size() == 1
342347
? UnpaddedCoerceToSeq[0]
@@ -362,7 +367,7 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
362367
if (isIllegalVectorType(Ty))
363368
return coerceIllegalVector(Ty, NSRN, NPRN);
364369

365-
if (!isAggregateTypeForABI(Ty)) {
370+
if (!passAsAggregateType(Ty)) {
366371
// Treat an enum type as its underlying type.
367372
if (const EnumType *EnumTy = Ty->getAs<EnumType>())
368373
Ty = EnumTy->getDecl()->getIntegerType();
@@ -417,7 +422,7 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
417422
// elsewhere for GNU compatibility.
418423
uint64_t Size = getContext().getTypeSize(Ty);
419424
bool IsEmpty = isEmptyRecord(getContext(), Ty, true);
420-
if (IsEmpty || Size == 0) {
425+
if (!Ty->isSVESizelessBuiltinType() && (IsEmpty || Size == 0)) {
421426
if (!getContext().getLangOpts().CPlusPlus || isDarwinPCS())
422427
return ABIArgInfo::getIgnore();
423428

@@ -504,7 +509,7 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
504509
if (RetTy->isVectorType() && getContext().getTypeSize(RetTy) > 128)
505510
return getNaturalAlignIndirect(RetTy);
506511

507-
if (!isAggregateTypeForABI(RetTy)) {
512+
if (!passAsAggregateType(RetTy)) {
508513
// Treat an enum type as its underlying type.
509514
if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
510515
RetTy = EnumTy->getDecl()->getIntegerType();
@@ -519,7 +524,8 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
519524
}
520525

521526
uint64_t Size = getContext().getTypeSize(RetTy);
522-
if (isEmptyRecord(getContext(), RetTy, true) || Size == 0)
527+
if (!RetTy->isSVESizelessBuiltinType() &&
528+
(isEmptyRecord(getContext(), RetTy, true) || Size == 0))
523529
return ABIArgInfo::getIgnore();
524530

525531
const Type *Base = nullptr;
@@ -654,6 +660,15 @@ bool AArch64ABIInfo::isZeroLengthBitfieldPermittedInHomogeneousAggregate()
654660
return true;
655661
}
656662

663+
bool AArch64ABIInfo::passAsAggregateType(QualType Ty) const {
664+
if (Kind == AArch64ABIKind::AAPCS && Ty->isSVESizelessBuiltinType()) {
665+
const auto *BT = Ty->getAs<BuiltinType>();
666+
return !BT->isSVECount() &&
667+
getContext().getBuiltinVectorTypeInfo(BT).NumVectors > 1;
668+
}
669+
return isAggregateTypeForABI(Ty);
670+
}
671+
657672
// Check if a type needs to be passed in registers as a Pure Scalable Type (as
658673
// defined by AAPCS64). Return the number of data vectors and the number of
659674
// predicate vectors in the type, into `NVec` and `NPred`, respectively. Upon
@@ -719,37 +734,38 @@ bool AArch64ABIInfo::passAsPureScalableType(
719734
return true;
720735
}
721736

722-
const auto *VT = Ty->getAs<VectorType>();
723-
if (!VT)
724-
return false;
737+
if (const auto *VT = Ty->getAs<VectorType>()) {
738+
if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
739+
++NPred;
740+
if (CoerceToSeq.size() + 1 > 12)
741+
return false;
742+
CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
743+
return true;
744+
}
725745

726-
if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
727-
++NPred;
728-
if (CoerceToSeq.size() + 1 > 12)
729-
return false;
730-
CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
731-
return true;
732-
}
746+
if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
747+
++NVec;
748+
if (CoerceToSeq.size() + 1 > 12)
749+
return false;
750+
CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
751+
return true;
752+
}
733753

734-
if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
735-
++NVec;
736-
if (CoerceToSeq.size() + 1 > 12)
737-
return false;
738-
CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
739-
return true;
754+
return false;
740755
}
741756

742-
if (!VT->isBuiltinType())
757+
if (!Ty->isBuiltinType())
743758
return false;
744759

745-
switch (cast<BuiltinType>(VT)->getKind()) {
760+
bool isPredicate;
761+
switch (Ty->getAs<BuiltinType>()->getKind()) {
746762
#define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
747763
case BuiltinType::Id: \
748-
++NVec; \
764+
isPredicate = false; \
749765
break;
750766
#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId) \
751767
case BuiltinType::Id: \
752-
++NPred; \
768+
isPredicate = true; \
753769
break;
754770
#define SVE_TYPE(Name, Id, SingletonId)
755771
#include "clang/Basic/AArch64SVEACLETypes.def"
@@ -761,6 +777,10 @@ bool AArch64ABIInfo::passAsPureScalableType(
761777
getContext().getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
762778
assert(Info.NumVectors > 0 && Info.NumVectors <= 4 &&
763779
"Expected 1, 2, 3 or 4 vectors!");
780+
if (isPredicate)
781+
NPred += Info.NumVectors;
782+
else
783+
NVec += Info.NumVectors;
764784
auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
765785
Info.EC.getKnownMinValue());
766786

clang/test/CodeGen/AArch64/pure-scalable-args.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,22 @@ void test_va_arg(int n, ...) {
459459
// CHECK-DARWIN-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ap)
460460
// CHECK-DARWIN-NEXT: ret void
461461
// CHECK-DARWIN-NEXT: }
462+
463+
// Regression test for incorrect passing of SVE vector tuples
464+
// The whole `y` need to be passed indirectly.
465+
void test_tuple_reg_count(svfloat32_t x, svfloat32x2_t y) {
466+
void test_tuple_reg_count_callee(svfloat32_t, svfloat32_t, svfloat32_t, svfloat32_t,
467+
svfloat32_t, svfloat32_t, svfloat32_t, svfloat32x2_t);
468+
test_tuple_reg_count_callee(x, x, x, x, x, x, x, y);
469+
}
470+
// CHECK-AAPCS: declare void @test_tuple_reg_count_callee(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, ptr noundef)
471+
// CHECK-DARWIN: declare void @test_tuple_reg_count_callee(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
472+
473+
// Regression test for incorrect passing of SVE vector tuples
474+
// The whole `y` need to be passed indirectly.
475+
void test_tuple_reg_count_bool(svboolx4_t x, svboolx4_t y) {
476+
void test_tuple_reg_count_bool_callee(svboolx4_t, svboolx4_t);
477+
test_tuple_reg_count_bool_callee(x, y);
478+
}
479+
// CHECK-AAPCS: declare void @test_tuple_reg_count_bool_callee(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, ptr noundef)
480+
// CHECK-DARWIN: declare void @test_tuple_reg_count_bool_callee(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)

0 commit comments

Comments
 (0)