Skip to content

[clang][RISCV] Enable struct of homogeneous scalable vector as function argument #78550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
3 commits were merged into the base branch on Feb 3, 2024
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
167 changes: 79 additions & 88 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,8 @@ void ClangToLLVMArgMapping::construct(const ASTContext &Context,
case ABIArgInfo::Direct: {
// FIXME: handle sseregparm someday...
llvm::StructType *STy = dyn_cast<llvm::StructType>(AI.getCoerceToType());
if (AI.isDirect() && AI.getCanBeFlattened() && STy) {
if (AI.isDirect() && AI.getCanBeFlattened() && STy &&
!STy->containsHomogeneousScalableVectorTypes()) {
IRArgs.NumberOfArgs = STy->getNumElements();
} else {
IRArgs.NumberOfArgs = 1;
Expand Down Expand Up @@ -1713,7 +1714,8 @@ CodeGenTypes::GetFunctionType(const CGFunctionInfo &FI) {
// FCAs, so we flatten them if this is safe to do for this argument.
llvm::Type *argType = ArgInfo.getCoerceToType();
llvm::StructType *st = dyn_cast<llvm::StructType>(argType);
if (st && ArgInfo.isDirect() && ArgInfo.getCanBeFlattened()) {
if (st && ArgInfo.isDirect() && ArgInfo.getCanBeFlattened() &&
!st->containsHomogeneousScalableVectorTypes()) {
assert(NumIRArgs == st->getNumElements());
for (unsigned i = 0, e = st->getNumElements(); i != e; ++i)
ArgTypes[FirstIRArg + i] = st->getElementType(i);
Expand Down Expand Up @@ -3206,6 +3208,25 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
}
}

llvm::StructType *STy =
dyn_cast<llvm::StructType>(ArgI.getCoerceToType());
llvm::TypeSize StructSize;
llvm::TypeSize PtrElementSize;
if (ArgI.isDirect() && ArgI.getCanBeFlattened() && STy &&
STy->getNumElements() > 1) {
StructSize = CGM.getDataLayout().getTypeAllocSize(STy);
PtrElementSize =
CGM.getDataLayout().getTypeAllocSize(ConvertTypeForMem(Ty));
if (STy->containsHomogeneousScalableVectorTypes()) {
assert(StructSize == PtrElementSize &&
"Only allow non-fractional movement of structure with"
"homogeneous scalable vector type");

ArgVals.push_back(ParamValue::forDirect(AI));
break;
}
}

Address Alloca = CreateMemTemp(Ty, getContext().getDeclAlign(Arg),
Arg->getName());

Expand All @@ -3214,53 +3235,29 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,

// Fast-isel and the optimizer generally like scalar values better than
// FCAs, so we flatten them if this is safe to do for this argument.
llvm::StructType *STy = dyn_cast<llvm::StructType>(ArgI.getCoerceToType());
if (ArgI.isDirect() && ArgI.getCanBeFlattened() && STy &&
STy->getNumElements() > 1) {
llvm::TypeSize StructSize = CGM.getDataLayout().getTypeAllocSize(STy);
llvm::TypeSize PtrElementSize =
CGM.getDataLayout().getTypeAllocSize(Ptr.getElementType());
if (StructSize.isScalable()) {
assert(STy->containsHomogeneousScalableVectorTypes() &&
"ABI only supports structure with homogeneous scalable vector "
"type");
assert(StructSize == PtrElementSize &&
"Only allow non-fractional movement of structure with"
"homogeneous scalable vector type");
assert(STy->getNumElements() == NumIRArgs);

llvm::Value *LoadedStructValue = llvm::PoisonValue::get(STy);
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
auto *AI = Fn->getArg(FirstIRArg + i);
AI->setName(Arg->getName() + ".coerce" + Twine(i));
LoadedStructValue =
Builder.CreateInsertValue(LoadedStructValue, AI, i);
}
uint64_t SrcSize = StructSize.getFixedValue();
uint64_t DstSize = PtrElementSize.getFixedValue();

Builder.CreateStore(LoadedStructValue, Ptr);
Address AddrToStoreInto = Address::invalid();
if (SrcSize <= DstSize) {
AddrToStoreInto = Ptr.withElementType(STy);
} else {
uint64_t SrcSize = StructSize.getFixedValue();
uint64_t DstSize = PtrElementSize.getFixedValue();

Address AddrToStoreInto = Address::invalid();
if (SrcSize <= DstSize) {
AddrToStoreInto = Ptr.withElementType(STy);
} else {
AddrToStoreInto =
CreateTempAlloca(STy, Alloca.getAlignment(), "coerce");
}
AddrToStoreInto =
CreateTempAlloca(STy, Alloca.getAlignment(), "coerce");
}

assert(STy->getNumElements() == NumIRArgs);
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
auto AI = Fn->getArg(FirstIRArg + i);
AI->setName(Arg->getName() + ".coerce" + Twine(i));
Address EltPtr = Builder.CreateStructGEP(AddrToStoreInto, i);
Builder.CreateStore(AI, EltPtr);
}
assert(STy->getNumElements() == NumIRArgs);
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
auto AI = Fn->getArg(FirstIRArg + i);
AI->setName(Arg->getName() + ".coerce" + Twine(i));
Address EltPtr = Builder.CreateStructGEP(AddrToStoreInto, i);
Builder.CreateStore(AI, EltPtr);
}

if (SrcSize > DstSize) {
Builder.CreateMemCpy(Ptr, AddrToStoreInto, DstSize);
}
if (SrcSize > DstSize) {
Builder.CreateMemCpy(Ptr, AddrToStoreInto, DstSize);
}
} else {
// Simple case, just do a coerced store of the argument into the alloca.
Expand Down Expand Up @@ -5277,6 +5274,24 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
break;
}

llvm::StructType *STy =
dyn_cast<llvm::StructType>(ArgInfo.getCoerceToType());
llvm::Type *SrcTy = ConvertTypeForMem(I->Ty);
llvm::TypeSize SrcTypeSize;
llvm::TypeSize DstTypeSize;
if (STy && ArgInfo.isDirect() && ArgInfo.getCanBeFlattened()) {
SrcTypeSize = CGM.getDataLayout().getTypeAllocSize(SrcTy);
DstTypeSize = CGM.getDataLayout().getTypeAllocSize(STy);
if (STy->containsHomogeneousScalableVectorTypes()) {
assert(SrcTypeSize == DstTypeSize &&
"Only allow non-fractional movement of structure with "
"homogeneous scalable vector type");

IRCallArgs[FirstIRArg] = I->getKnownRValue().getScalarVal();
break;
}
}

// FIXME: Avoid the conversion through memory if possible.
Address Src = Address::invalid();
if (!I->isAggregate()) {
Expand All @@ -5292,54 +5307,30 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,

// Fast-isel and the optimizer generally like scalar values better than
// FCAs, so we flatten them if this is safe to do for this argument.
llvm::StructType *STy =
dyn_cast<llvm::StructType>(ArgInfo.getCoerceToType());
if (STy && ArgInfo.isDirect() && ArgInfo.getCanBeFlattened()) {
llvm::Type *SrcTy = Src.getElementType();
llvm::TypeSize SrcTypeSize =
CGM.getDataLayout().getTypeAllocSize(SrcTy);
llvm::TypeSize DstTypeSize = CGM.getDataLayout().getTypeAllocSize(STy);
if (SrcTypeSize.isScalable()) {
assert(STy->containsHomogeneousScalableVectorTypes() &&
"ABI only supports structure with homogeneous scalable vector "
"type");
assert(SrcTypeSize == DstTypeSize &&
"Only allow non-fractional movement of structure with "
"homogeneous scalable vector type");
assert(NumIRArgs == STy->getNumElements());

llvm::Value *StoredStructValue =
Builder.CreateLoad(Src, Src.getName() + ".tuple");
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
llvm::Value *Extract = Builder.CreateExtractValue(
StoredStructValue, i, Src.getName() + ".extract" + Twine(i));
IRCallArgs[FirstIRArg + i] = Extract;
}
uint64_t SrcSize = SrcTypeSize.getFixedValue();
uint64_t DstSize = DstTypeSize.getFixedValue();

// If the source type is smaller than the destination type of the
// coerce-to logic, copy the source value into a temp alloca the size
// of the destination type to allow loading all of it. The bits past
// the source value are left undef.
if (SrcSize < DstSize) {
Address TempAlloca = CreateTempAlloca(STy, Src.getAlignment(),
Src.getName() + ".coerce");
Builder.CreateMemCpy(TempAlloca, Src, SrcSize);
Src = TempAlloca;
} else {
uint64_t SrcSize = SrcTypeSize.getFixedValue();
uint64_t DstSize = DstTypeSize.getFixedValue();

// If the source type is smaller than the destination type of the
// coerce-to logic, copy the source value into a temp alloca the size
// of the destination type to allow loading all of it. The bits past
// the source value are left undef.
if (SrcSize < DstSize) {
Address TempAlloca = CreateTempAlloca(STy, Src.getAlignment(),
Src.getName() + ".coerce");
Builder.CreateMemCpy(TempAlloca, Src, SrcSize);
Src = TempAlloca;
} else {
Src = Src.withElementType(STy);
}
Src = Src.withElementType(STy);
}

assert(NumIRArgs == STy->getNumElements());
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
Address EltPtr = Builder.CreateStructGEP(Src, i);
llvm::Value *LI = Builder.CreateLoad(EltPtr);
if (ArgHasMaybeUndefAttr)
LI = Builder.CreateFreeze(LI);
IRCallArgs[FirstIRArg + i] = LI;
}
assert(NumIRArgs == STy->getNumElements());
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
Address EltPtr = Builder.CreateStructGEP(Src, i);
llvm::Value *LI = Builder.CreateLoad(EltPtr);
if (ArgHasMaybeUndefAttr)
LI = Builder.CreateFreeze(LI);
IRCallArgs[FirstIRArg + i] = LI;
}
} else {
// In the simple case, just pass the coerced loaded value.
Expand Down
Loading