
[Clang][AArch64] Fix Pure Scalable Types argument passing and return #112747


Merged
merged 6 commits into llvm:main from the pure-scalable-args branch on Oct 28, 2024

Conversation

momchil-velikov (Collaborator)

Pure Scalable Types are defined in AAPCS64 here:
https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#pure-scalable-types-psts

And should be passed according to Rule C.7 here:
https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#682parameter-passing-rules

This part of the ABI is completely unimplemented in Clang; instead, it treats PSTs sometimes as HFAs/HVAs and sometimes as general composite types.
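
For reference, a PST is an aggregate composed, possibly recursively, only of SVE data vectors and predicates. A minimal sketch of one in C (hypothetical code, not taken from the patch's tests; assumes compiling with -msve-vector-bits=128 so the fixed-length SVE types below are available):

#include <arm_sve.h>

typedef svfloat32_t fvec32 __attribute__((arm_sve_vector_bits(128)));
typedef svfloat64_t fvec64 __attribute__((arm_sve_vector_bits(128)));
typedef svbool_t fbool __attribute__((arm_sve_vector_bits(128)));

/* A Pure Scalable Type: every member is an SVE vector or predicate. */
typedef struct {
  fvec32 x; /* needs one Z register */
  fvec64 y; /* needs one Z register */
  fbool p;  /* needs one P register */
} PST;

/* Per rule C.7, a named PST argument is passed expanded: x and y in two
   Z registers and p in one P register, or indirectly in memory when not
   enough registers remain. */
void f(PST a);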

This patch implements the rules for passing PSTs by employing the CoerceAndExpand method and extending it to:

  • allow array types in the coerceToType; now only [N x i8] arrays are considered padding.
  • allow a mismatch between the elements of the coerceToType and the elements of the unpaddedCoerceToType; AArch64 uses this to map fixed-length vector types to SVE vector types.

Correctly passing a PST argument requires a decision in Clang about whether to pass it in memory or in registers, or, equivalently, whether to use the Indirect or Expand/CoerceAndExpand method. Making that decision in the AArch64 backend was considered substantially harder (or not practically possible).
Hence this patch implements the register counting from AAPCS64 (cf. NSRN, NPRN) to guide Clang's decision.
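
A minimal sketch of that decision, simplified from the patch's coerceAndExpandPureScalableAggregate (the function name pstNeedsIndirect is illustrative, not from the patch):

#include <stdbool.h>

/* NSRN/NPRN count the SIMD&FP (z0-z7) and predicate (p0-p3) argument
   registers already allocated; NVec/NPred are what the PST needs.
   Returns true when the PST must be passed indirectly in memory. */
static bool pstNeedsIndirect(unsigned *NSRN, unsigned *NPRN,
                             unsigned NVec, unsigned NPred) {
  if (*NSRN + NVec > 8 || *NPRN + NPred > 4)
    return true;    /* not enough registers left */
  *NSRN += NVec;    /* claim the vector registers */
  *NPRN += NPred;   /* claim the predicate registers */
  return false;     /* pass expanded in registers */
}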

@llvmbot added the clang (Clang issues not falling into any other category), backend:AArch64, and clang:codegen (IR generation bugs: mangling, exceptions, etc.) labels on Oct 17, 2024
llvmbot (Member) commented Oct 17, 2024

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-clang-codegen

Author: Momchil Velikov (momchil-velikov)



Patch is 42.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112747.diff

4 Files Affected:

  • (modified) clang/include/clang/CodeGen/CGFunctionInfo.h (+2-11)
  • (modified) clang/lib/CodeGen/CGCall.cpp (+64-21)
  • (modified) clang/lib/CodeGen/Targets/AArch64.cpp (+289-37)
  • (added) clang/test/CodeGen/aarch64-pure-scalable-args.c (+314)
diff --git a/clang/include/clang/CodeGen/CGFunctionInfo.h b/clang/include/clang/CodeGen/CGFunctionInfo.h
index d19f84d198876f..915f676d7d3905 100644
--- a/clang/include/clang/CodeGen/CGFunctionInfo.h
+++ b/clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -272,11 +272,6 @@ class ABIArgInfo {
     unsigned unpaddedIndex = 0;
     for (auto eltType : coerceToType->elements()) {
       if (isPaddingForCoerceAndExpand(eltType)) continue;
-      if (unpaddedStruct) {
-        assert(unpaddedStruct->getElementType(unpaddedIndex) == eltType);
-      } else {
-        assert(unpaddedIndex == 0 && unpaddedCoerceToType == eltType);
-      }
       unpaddedIndex++;
     }
 
@@ -295,12 +290,8 @@ class ABIArgInfo {
   }
 
   static bool isPaddingForCoerceAndExpand(llvm::Type *eltType) {
-    if (eltType->isArrayTy()) {
-      assert(eltType->getArrayElementType()->isIntegerTy(8));
-      return true;
-    } else {
-      return false;
-    }
+    return eltType->isArrayTy() &&
+           eltType->getArrayElementType()->isIntegerTy(8);
   }
 
   Kind getKind() const { return TheKind; }
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 4ae981e4013e9c..3c75dae9918af9 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1410,6 +1410,30 @@ static Address emitAddressAtOffset(CodeGenFunction &CGF, Address addr,
   return addr;
 }
 
+static std::pair<llvm::Value *, bool>
+CoerceScalableToFixed(CodeGenFunction &CGF, llvm::FixedVectorType *ToTy,
+                      llvm::ScalableVectorType *FromTy, llvm::Value *V,
+                      StringRef Name = "") {
+  // If we are casting a scalable i1 predicate vector to a fixed i8
+  // vector, first bitcast the source.
+  if (FromTy->getElementType()->isIntegerTy(1) &&
+      FromTy->getElementCount().isKnownMultipleOf(8) &&
+      ToTy->getElementType() == CGF.Builder.getInt8Ty()) {
+    FromTy = llvm::ScalableVectorType::get(
+        ToTy->getElementType(),
+        FromTy->getElementCount().getKnownMinValue() / 8);
+    V = CGF.Builder.CreateBitCast(V, FromTy);
+  }
+  if (FromTy->getElementType() == ToTy->getElementType()) {
+    llvm::Value *Zero = llvm::Constant::getNullValue(CGF.CGM.Int64Ty);
+
+    V->setName(Name + ".coerce");
+    V = CGF.Builder.CreateExtractVector(ToTy, V, Zero, "cast.fixed");
+    return {V, true};
+  }
+  return {V, false};
+}
+
 namespace {
 
 /// Encapsulates information about the way function arguments from
@@ -3196,26 +3220,14 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
       // a VLAT at the function boundary and the types match up, use
       // llvm.vector.extract to convert back to the original VLST.
       if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(ConvertType(Ty))) {
-        llvm::Value *Coerced = Fn->getArg(FirstIRArg);
+        llvm::Value *ArgVal = Fn->getArg(FirstIRArg);
         if (auto *VecTyFrom =
-                dyn_cast<llvm::ScalableVectorType>(Coerced->getType())) {
-          // If we are casting a scalable i1 predicate vector to a fixed i8
-          // vector, bitcast the source and use a vector extract.
-          if (VecTyFrom->getElementType()->isIntegerTy(1) &&
-              VecTyFrom->getElementCount().isKnownMultipleOf(8) &&
-              VecTyTo->getElementType() == Builder.getInt8Ty()) {
-            VecTyFrom = llvm::ScalableVectorType::get(
-                VecTyTo->getElementType(),
-                VecTyFrom->getElementCount().getKnownMinValue() / 8);
-            Coerced = Builder.CreateBitCast(Coerced, VecTyFrom);
-          }
-          if (VecTyFrom->getElementType() == VecTyTo->getElementType()) {
-            llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
-
+                dyn_cast<llvm::ScalableVectorType>(ArgVal->getType())) {
+          auto [Coerced, Extracted] = CoerceScalableToFixed(
+              *this, VecTyTo, VecTyFrom, ArgVal, Arg->getName());
+          if (Extracted) {
             assert(NumIRArgs == 1);
-            Coerced->setName(Arg->getName() + ".coerce");
-            ArgVals.push_back(ParamValue::forDirect(Builder.CreateExtractVector(
-                VecTyTo, Coerced, Zero, "cast.fixed")));
+            ArgVals.push_back(ParamValue::forDirect(Coerced));
             break;
           }
         }
@@ -3326,16 +3338,33 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
       ArgVals.push_back(ParamValue::forIndirect(alloca));
 
       auto coercionType = ArgI.getCoerceAndExpandType();
+      auto unpaddedCoercionType = ArgI.getUnpaddedCoerceAndExpandType();
+      auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
+
       alloca = alloca.withElementType(coercionType);
 
       unsigned argIndex = FirstIRArg;
+      unsigned unpaddedIndex = 0;
       for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
         llvm::Type *eltType = coercionType->getElementType(i);
         if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
           continue;
 
         auto eltAddr = Builder.CreateStructGEP(alloca, i);
-        auto elt = Fn->getArg(argIndex++);
+        llvm::Value *elt = Fn->getArg(argIndex++);
+
+        auto paramType = unpaddedStruct
+                             ? unpaddedStruct->getElementType(unpaddedIndex++)
+                             : unpaddedCoercionType;
+
+        if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(eltType)) {
+          if (auto *VecTyFrom = dyn_cast<llvm::ScalableVectorType>(paramType)) {
+            bool Extracted;
+            std::tie(elt, Extracted) = CoerceScalableToFixed(
+                *this, VecTyTo, VecTyFrom, elt, elt->getName());
+            assert(Extracted && "Unexpected scalable to fixed vector coercion");
+          }
+        }
         Builder.CreateStore(elt, eltAddr);
       }
       assert(argIndex == FirstIRArg + NumIRArgs);
@@ -3930,17 +3959,24 @@ void CodeGenFunction::EmitFunctionEpilog(const CGFunctionInfo &FI,
 
   case ABIArgInfo::CoerceAndExpand: {
     auto coercionType = RetAI.getCoerceAndExpandType();
+    auto unpaddedCoercionType = RetAI.getUnpaddedCoerceAndExpandType();
+    auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
 
     // Load all of the coerced elements out into results.
     llvm::SmallVector<llvm::Value*, 4> results;
     Address addr = ReturnValue.withElementType(coercionType);
+    unsigned unpaddedIndex = 0;
     for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
       auto coercedEltType = coercionType->getElementType(i);
       if (ABIArgInfo::isPaddingForCoerceAndExpand(coercedEltType))
         continue;
 
       auto eltAddr = Builder.CreateStructGEP(addr, i);
-      auto elt = Builder.CreateLoad(eltAddr);
+      llvm::Value *elt = CreateCoercedLoad(
+          eltAddr,
+          unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
+                         : unpaddedCoercionType,
+          *this);
       results.push_back(elt);
     }
 
@@ -5468,6 +5504,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
     case ABIArgInfo::CoerceAndExpand: {
       auto coercionType = ArgInfo.getCoerceAndExpandType();
       auto layout = CGM.getDataLayout().getStructLayout(coercionType);
+      auto unpaddedCoercionType = ArgInfo.getUnpaddedCoerceAndExpandType();
+      auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
 
       llvm::Value *tempSize = nullptr;
       Address addr = Address::invalid();
@@ -5498,11 +5536,16 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
       addr = addr.withElementType(coercionType);
 
       unsigned IRArgPos = FirstIRArg;
+      unsigned unpaddedIndex = 0;
       for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
         llvm::Type *eltType = coercionType->getElementType(i);
         if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
         Address eltAddr = Builder.CreateStructGEP(addr, i);
-        llvm::Value *elt = Builder.CreateLoad(eltAddr);
+        llvm::Value *elt = CreateCoercedLoad(
+            eltAddr,
+            unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
+                           : unpaddedCoercionType,
+            *this);
         if (ArgHasMaybeUndefAttr)
           elt = Builder.CreateFreeze(elt);
         IRCallArgs[IRArgPos++] = elt;
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index ec617eec67192c..269b5b352bfd84 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -36,8 +36,15 @@ class AArch64ABIInfo : public ABIInfo {
 
   ABIArgInfo classifyReturnType(QualType RetTy, bool IsVariadic) const;
   ABIArgInfo classifyArgumentType(QualType RetTy, bool IsVariadic,
-                                  unsigned CallingConvention) const;
-  ABIArgInfo coerceIllegalVector(QualType Ty) const;
+                                  unsigned CallingConvention, unsigned &NSRN,
+                                  unsigned &NPRN) const;
+  llvm::Type *convertFixedToScalableVectorType(const VectorType *VT) const;
+  ABIArgInfo coerceIllegalVector(QualType Ty, unsigned &NSRN,
+                                 unsigned &NPRN) const;
+  ABIArgInfo coerceAndExpandPureScalableAggregate(
+      QualType Ty, unsigned NVec, unsigned NPred,
+      const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
+      unsigned &NPRN) const;
   bool isHomogeneousAggregateBaseType(QualType Ty) const override;
   bool isHomogeneousAggregateSmallEnough(const Type *Ty,
                                          uint64_t Members) const override;
@@ -45,14 +52,21 @@ class AArch64ABIInfo : public ABIInfo {
 
   bool isIllegalVectorType(QualType Ty) const;
 
+  bool isPureScalableType(QualType Ty, unsigned &NV, unsigned &NP,
+                          SmallVectorImpl<llvm::Type *> &CoerceToSeq) const;
+
+  void flattenType(llvm::Type *Ty,
+                   SmallVectorImpl<llvm::Type *> &Flattened) const;
+
   void computeInfo(CGFunctionInfo &FI) const override {
     if (!::classifyReturnType(getCXXABI(), FI, *this))
       FI.getReturnInfo() =
           classifyReturnType(FI.getReturnType(), FI.isVariadic());
 
+    unsigned NSRN = 0, NPRN = 0;
     for (auto &it : FI.arguments())
       it.info = classifyArgumentType(it.type, FI.isVariadic(),
-                                     FI.getCallingConvention());
+                                     FI.getCallingConvention(), NSRN, NPRN);
   }
 
   RValue EmitDarwinVAArg(Address VAListAddr, QualType Ty, CodeGenFunction &CGF,
@@ -201,65 +215,83 @@ void WindowsAArch64TargetCodeGenInfo::setTargetAttributes(
 }
 }
 
-ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty) const {
-  assert(Ty->isVectorType() && "expected vector type!");
+llvm::Type *
+AArch64ABIInfo::convertFixedToScalableVectorType(const VectorType *VT) const {
+  assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
 
-  const auto *VT = Ty->castAs<VectorType>();
   if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
-    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
     assert(VT->getElementType()->castAs<BuiltinType>()->getKind() ==
                BuiltinType::UChar &&
            "unexpected builtin type for SVE predicate!");
-    return ABIArgInfo::getDirect(llvm::ScalableVectorType::get(
-        llvm::Type::getInt1Ty(getVMContext()), 16));
+    return llvm::ScalableVectorType::get(llvm::Type::getInt1Ty(getVMContext()),
+                                         16);
   }
 
   if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
-    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
-
     const auto *BT = VT->getElementType()->castAs<BuiltinType>();
-    llvm::ScalableVectorType *ResType = nullptr;
     switch (BT->getKind()) {
     default:
       llvm_unreachable("unexpected builtin type for SVE vector!");
+
     case BuiltinType::SChar:
     case BuiltinType::UChar:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt8Ty(getVMContext()), 16);
-      break;
+
     case BuiltinType::Short:
     case BuiltinType::UShort:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt16Ty(getVMContext()), 8);
-      break;
+
     case BuiltinType::Int:
     case BuiltinType::UInt:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt32Ty(getVMContext()), 4);
-      break;
+
     case BuiltinType::Long:
     case BuiltinType::ULong:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt64Ty(getVMContext()), 2);
-      break;
+
     case BuiltinType::Half:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getHalfTy(getVMContext()), 8);
-      break;
+
     case BuiltinType::Float:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getFloatTy(getVMContext()), 4);
-      break;
+
     case BuiltinType::Double:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getDoubleTy(getVMContext()), 2);
-      break;
+
     case BuiltinType::BFloat16:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getBFloatTy(getVMContext()), 8);
-      break;
     }
-    return ABIArgInfo::getDirect(ResType);
+  }
+
+  llvm_unreachable("expected fixed-length SVE vector");
+}
+
+ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty, unsigned &NSRN,
+                                               unsigned &NPRN) const {
+  assert(Ty->isVectorType() && "expected vector type!");
+
+  const auto *VT = Ty->castAs<VectorType>();
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
+    assert(VT->getElementType()->castAs<BuiltinType>()->getKind() ==
+               BuiltinType::UChar &&
+           "unexpected builtin type for SVE predicate!");
+    NPRN = std::min(NPRN + 1, 4u);
+    return ABIArgInfo::getDirect(llvm::ScalableVectorType::get(
+        llvm::Type::getInt1Ty(getVMContext()), 16));
+  }
+
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
+    NSRN = std::min(NSRN + 1, 8u);
+    return ABIArgInfo::getDirect(convertFixedToScalableVectorType(VT));
   }
 
   uint64_t Size = getContext().getTypeSize(Ty);
@@ -273,26 +305,53 @@ ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty) const {
     return ABIArgInfo::getDirect(ResType);
   }
   if (Size == 64) {
+    NSRN = std::min(NSRN + 1, 8u);
     auto *ResType =
         llvm::FixedVectorType::get(llvm::Type::getInt32Ty(getVMContext()), 2);
     return ABIArgInfo::getDirect(ResType);
   }
   if (Size == 128) {
+    NSRN = std::min(NSRN + 1, 8u);
     auto *ResType =
         llvm::FixedVectorType::get(llvm::Type::getInt32Ty(getVMContext()), 4);
     return ABIArgInfo::getDirect(ResType);
   }
+
   return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
 }
 
-ABIArgInfo
-AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
-                                     unsigned CallingConvention) const {
+ABIArgInfo AArch64ABIInfo::coerceAndExpandPureScalableAggregate(
+    QualType Ty, unsigned NVec, unsigned NPred,
+    const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
+    unsigned &NPRN) const {
+  if (NSRN + NVec > 8 || NPRN + NPred > 4)
+    return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
+  NSRN += NVec;
+  NPRN += NPred;
+
+  llvm::Type *UnpaddedCoerceToType =
+      UnpaddedCoerceToSeq.size() == 1
+          ? UnpaddedCoerceToSeq[0]
+          : llvm::StructType::get(CGT.getLLVMContext(), UnpaddedCoerceToSeq,
+                                  true);
+
+  SmallVector<llvm::Type *> CoerceToSeq;
+  flattenType(CGT.ConvertType(Ty), CoerceToSeq);
+  auto *CoerceToType =
+      llvm::StructType::get(CGT.getLLVMContext(), CoerceToSeq, false);
+
+  return ABIArgInfo::getCoerceAndExpand(CoerceToType, UnpaddedCoerceToType);
+}
+
+ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
+                                                unsigned CallingConvention,
+                                                unsigned &NSRN,
+                                                unsigned &NPRN) const {
   Ty = useFirstFieldIfTransparentUnion(Ty);
 
   // Handle illegal vector types here.
   if (isIllegalVectorType(Ty))
-    return coerceIllegalVector(Ty);
+    return coerceIllegalVector(Ty, NSRN, NPRN);
 
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
@@ -303,6 +362,20 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
       if (EIT->getNumBits() > 128)
         return getNaturalAlignIndirect(Ty, false);
 
+    if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
+      if (BT->isSVEBool() || BT->isSVECount())
+        NPRN = std::min(NPRN + 1, 4u);
+      else if (BT->getKind() == BuiltinType::SveBoolx2)
+        NPRN = std::min(NPRN + 2, 4u);
+      else if (BT->getKind() == BuiltinType::SveBoolx4)
+        NPRN = std::min(NPRN + 4, 4u);
+      else if (BT->isFloatingPoint() || BT->isVectorType())
+        NSRN = std::min(NSRN + 1, 8u);
+      else if (BT->isSVESizelessBuiltinType())
+        NSRN = std::min(
+            NSRN + getContext().getBuiltinVectorTypeInfo(BT).NumVectors, 8u);
+    }
+
     return (isPromotableIntegerTypeForABI(Ty) && isDarwinPCS()
                 ? ABIArgInfo::getExtend(Ty, CGT.ConvertType(Ty))
                 : ABIArgInfo::getDirect());
@@ -339,6 +412,7 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
   // In variadic functions on Windows, all composite types are treated alike,
   // no special handling of HFAs/HVAs.
   if (!IsWinVariadic && isHomogeneousAggregate(Ty, Base, Members)) {
+    NSRN = std::min(NSRN + Members, uint64_t(8));
     if (Kind != AArch64ABIKind::AAPCS)
       return ABIArgInfo::getDirect(
           llvm::ArrayType::get(CGT.ConvertType(QualType(Base, 0)), Members));
@@ -353,6 +427,17 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
         nullptr, true, Align);
   }
 
+  // In AAPCS named arguments of a Pure Scalable Type are passed expanded in
+  // registers, or indirectly if there are not enough registers.
+  if (Kind == AArch64ABIKind::AAPCS && !IsVariadic) {
+    unsigned NVec = 0, NPred = 0;
+    SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
+    if (isPureScalableType(Ty, NVec, NPred, UnpaddedCoerceToSeq) &&
+        (NVec + NPred) > 0)
+      return coerceAndExpandPureScalableAggregate(
+          Ty, NVec, NPred, UnpaddedCoerceToSeq, NSRN, NPRN);
+  }
+
   // Aggregates <= 16 bytes are passed directly in registers or on the stack.
   if (Size <= 128) {
     // On RenderScript, coerce Aggregates <= 16 bytes to an integer array of
@@ -389,8 +474,10 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
 
   if (const auto *VT = RetTy->getAs<VectorType>()) {
     if (VT->getVectorKind() == VectorKind::SveFixedLengthData ||
-        VT->getVectorKind() == VectorKind::SveFixedLengthPredicate)
-      return coerceIllegalVector(RetTy);
+        VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+      unsigned NSRN = 0, NPRN = 0;
+      return coerceIllegalVector(RetTy, NSRN, NPRN);
+    }
   }
 
   // Large vector types should be returned via memory.
@@ -423,6 +510,19 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
     // Homogeneous Floating-point Aggregates (HFAs) are returned directly.
     return ABIArgInfo::getDirect();
 
+  // In AAPCS return values of a Pure Scalable type are treated is a first named
+  // argument and passed expanded in registers, or indirectly if there are not
+  // enough registers.
+  if (Kind == AArch64ABIKind::AAPCS) {
+    unsigned NSRN = 0, NPRN = 0;
+    unsigned NVec = 0, NPred = 0;
+    SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
+    if (isPureScalableType(RetTy, NVec, NPred, UnpaddedCoerceToSeq) &&
+        (NVec + NPred) > 0)
+      return...
[truncated]

llvmbot (Member) commented Oct 17, 2024

@llvm/pr-subscribers-backend-aarch64



github-actions bot commented Oct 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@efriedma-quic (Collaborator) left a comment

Clang tracks the number of registers used on other targets, so it's not unprecedented... just an unfortunate consequence of the limitations of ABI in LLVM IR.

Additional notes from the PR's commits: the AAPCS64 register counting is implemented in Clang even though the AArch64 backend does not support it (currently). Later commits also:
  • fix incorrect determination of whether an argument is named
  • fix a case where a small (<= 16 bytes) unnamed PST argument was passed directly, and add tests accordingly
  • correctly count SIMD vectors
  • handle the new mfloat8x8_t and mfloat8x16_t types, with tests updated accordingly
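
To illustrate why counting SIMD vectors matters (a hypothetical signature, not from the patch's tests): Neon and SVE argument registers share one counter, since Vn aliases the low 128 bits of Zn, so a preceding Neon argument shifts where a later PST lands.

#include <arm_neon.h>
#include <arm_sve.h>

typedef svfloat32_t fvec32 __attribute__((arm_sve_vector_bits(128)));
typedef struct { fvec32 x, y; } PST;

/* v occupies v0, advancing NSRN to 1; the PST then takes z1 and z2.
   If NSRN were not advanced for the Neon vector, the PST would be
   classified against the wrong number of free registers. */
void h(float32x4_t v, PST s);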
@jthackray (Contributor) left a comment

LGTM

momchil-velikov (Collaborator, Author)

Thanks for the reviews, much appreciated!

momchil-velikov merged commit 53f7f8e into llvm:main on Oct 28, 2024
8 checks passed
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
…llvm#112747)

momchil-velikov deleted the pure-scalable-args branch on November 13, 2024 at 09:32
momchil-velikov added a commit that referenced this pull request Dec 23, 2024
The fix for passing Pure Scalable Types (#112747) was incomplete: it didn't correctly handle tuples of SVE vectors (e.g. `svboolx2_t`, `svfloat32x4_t`, etc.).

These types are Pure Scalable Types and should be passed either entirely in vector registers or indirectly in memory, not split.
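
For example (illustrative, not taken from the follow-up patch), an SVE tuple type is itself a PST:

#include <arm_sve.h>

/* svfloat32x4_t is a tuple of four scalable vectors: it is passed in
   z0-z3 when all four registers are free, and indirectly in memory
   otherwise; never partly in registers and partly on the stack. */
void g(svfloat32x4_t t);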
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 10, 2025