Skip to content

[Clang][AArch64] Fix Pure Scalables Types argument passing and return #112747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions clang/include/clang/CodeGen/CGFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,8 @@ class ABIArgInfo {
// in the unpadded type.
unsigned unpaddedIndex = 0;
for (auto eltType : coerceToType->elements()) {
if (isPaddingForCoerceAndExpand(eltType)) continue;
if (unpaddedStruct) {
assert(unpaddedStruct->getElementType(unpaddedIndex) == eltType);
} else {
assert(unpaddedIndex == 0 && unpaddedCoerceToType == eltType);
}
if (isPaddingForCoerceAndExpand(eltType))
continue;
unpaddedIndex++;
}

Expand All @@ -295,12 +291,8 @@ class ABIArgInfo {
}

static bool isPaddingForCoerceAndExpand(llvm::Type *eltType) {
if (eltType->isArrayTy()) {
assert(eltType->getArrayElementType()->isIntegerTy(8));
return true;
} else {
return false;
}
return eltType->isArrayTy() &&
eltType->getArrayElementType()->isIntegerTy(8);
}

Kind getKind() const { return TheKind; }
Expand Down
85 changes: 64 additions & 21 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1410,6 +1410,30 @@ static Address emitAddressAtOffset(CodeGenFunction &CGF, Address addr,
return addr;
}

static std::pair<llvm::Value *, bool>
CoerceScalableToFixed(CodeGenFunction &CGF, llvm::FixedVectorType *ToTy,
llvm::ScalableVectorType *FromTy, llvm::Value *V,
StringRef Name = "") {
// If we are casting a scalable i1 predicate vector to a fixed i8
// vector, first bitcast the source.
if (FromTy->getElementType()->isIntegerTy(1) &&
FromTy->getElementCount().isKnownMultipleOf(8) &&
ToTy->getElementType() == CGF.Builder.getInt8Ty()) {
FromTy = llvm::ScalableVectorType::get(
ToTy->getElementType(),
FromTy->getElementCount().getKnownMinValue() / 8);
V = CGF.Builder.CreateBitCast(V, FromTy);
}
if (FromTy->getElementType() == ToTy->getElementType()) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.CGM.Int64Ty);

V->setName(Name + ".coerce");
V = CGF.Builder.CreateExtractVector(ToTy, V, Zero, "cast.fixed");
return {V, true};
}
return {V, false};
}

namespace {

/// Encapsulates information about the way function arguments from
Expand Down Expand Up @@ -3196,26 +3220,14 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
// a VLAT at the function boundary and the types match up, use
// llvm.vector.extract to convert back to the original VLST.
if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(ConvertType(Ty))) {
llvm::Value *Coerced = Fn->getArg(FirstIRArg);
llvm::Value *ArgVal = Fn->getArg(FirstIRArg);
if (auto *VecTyFrom =
dyn_cast<llvm::ScalableVectorType>(Coerced->getType())) {
// If we are casting a scalable i1 predicate vector to a fixed i8
// vector, bitcast the source and use a vector extract.
if (VecTyFrom->getElementType()->isIntegerTy(1) &&
VecTyFrom->getElementCount().isKnownMultipleOf(8) &&
VecTyTo->getElementType() == Builder.getInt8Ty()) {
VecTyFrom = llvm::ScalableVectorType::get(
VecTyTo->getElementType(),
VecTyFrom->getElementCount().getKnownMinValue() / 8);
Coerced = Builder.CreateBitCast(Coerced, VecTyFrom);
}
if (VecTyFrom->getElementType() == VecTyTo->getElementType()) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);

dyn_cast<llvm::ScalableVectorType>(ArgVal->getType())) {
auto [Coerced, Extracted] = CoerceScalableToFixed(
*this, VecTyTo, VecTyFrom, ArgVal, Arg->getName());
if (Extracted) {
assert(NumIRArgs == 1);
Coerced->setName(Arg->getName() + ".coerce");
ArgVals.push_back(ParamValue::forDirect(Builder.CreateExtractVector(
VecTyTo, Coerced, Zero, "cast.fixed")));
ArgVals.push_back(ParamValue::forDirect(Coerced));
break;
}
}
Expand Down Expand Up @@ -3326,16 +3338,33 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
ArgVals.push_back(ParamValue::forIndirect(alloca));

auto coercionType = ArgI.getCoerceAndExpandType();
auto unpaddedCoercionType = ArgI.getUnpaddedCoerceAndExpandType();
auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);

alloca = alloca.withElementType(coercionType);

unsigned argIndex = FirstIRArg;
unsigned unpaddedIndex = 0;
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
llvm::Type *eltType = coercionType->getElementType(i);
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
continue;

auto eltAddr = Builder.CreateStructGEP(alloca, i);
auto elt = Fn->getArg(argIndex++);
llvm::Value *elt = Fn->getArg(argIndex++);

auto paramType = unpaddedStruct
? unpaddedStruct->getElementType(unpaddedIndex++)
: unpaddedCoercionType;

if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(eltType)) {
if (auto *VecTyFrom = dyn_cast<llvm::ScalableVectorType>(paramType)) {
bool Extracted;
std::tie(elt, Extracted) = CoerceScalableToFixed(
*this, VecTyTo, VecTyFrom, elt, elt->getName());
assert(Extracted && "Unexpected scalable to fixed vector coercion");
}
}
Builder.CreateStore(elt, eltAddr);
}
assert(argIndex == FirstIRArg + NumIRArgs);
Expand Down Expand Up @@ -3930,17 +3959,24 @@ void CodeGenFunction::EmitFunctionEpilog(const CGFunctionInfo &FI,

case ABIArgInfo::CoerceAndExpand: {
auto coercionType = RetAI.getCoerceAndExpandType();
auto unpaddedCoercionType = RetAI.getUnpaddedCoerceAndExpandType();
auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);

// Load all of the coerced elements out into results.
llvm::SmallVector<llvm::Value*, 4> results;
Address addr = ReturnValue.withElementType(coercionType);
unsigned unpaddedIndex = 0;
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
auto coercedEltType = coercionType->getElementType(i);
if (ABIArgInfo::isPaddingForCoerceAndExpand(coercedEltType))
continue;

auto eltAddr = Builder.CreateStructGEP(addr, i);
auto elt = Builder.CreateLoad(eltAddr);
llvm::Value *elt = CreateCoercedLoad(
eltAddr,
unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
: unpaddedCoercionType,
*this);
results.push_back(elt);
}

Expand Down Expand Up @@ -5468,6 +5504,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
case ABIArgInfo::CoerceAndExpand: {
auto coercionType = ArgInfo.getCoerceAndExpandType();
auto layout = CGM.getDataLayout().getStructLayout(coercionType);
auto unpaddedCoercionType = ArgInfo.getUnpaddedCoerceAndExpandType();
auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);

llvm::Value *tempSize = nullptr;
Address addr = Address::invalid();
Expand Down Expand Up @@ -5498,11 +5536,16 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
addr = addr.withElementType(coercionType);

unsigned IRArgPos = FirstIRArg;
unsigned unpaddedIndex = 0;
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
llvm::Type *eltType = coercionType->getElementType(i);
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
Address eltAddr = Builder.CreateStructGEP(addr, i);
llvm::Value *elt = Builder.CreateLoad(eltAddr);
llvm::Value *elt = CreateCoercedLoad(
eltAddr,
unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
: unpaddedCoercionType,
*this);
if (ArgHasMaybeUndefAttr)
elt = Builder.CreateFreeze(elt);
IRCallArgs[IRArgPos++] = elt;
Expand Down
Loading
Loading