Skip to content

Commit 2c9c22c

Browse files
authored
[ARM64EC] Fix thunks for vector args (#96003)
The checks when building a thunk to decide if an arg needed to be cast to/from an integer or redirected via a pointer didn't match how arg types were changed in `canonicalizeThunkType`, this caused LLVM to ICE when using vector types as args due to incorrect types in a call instruction. Instead of duplicating these checks, we should check if the arg type differs between x64 and AArch64 and then cast or redirect as appropriate.
1 parent 8fa4fe1 commit 2c9c22c

File tree

3 files changed

+326
-53
lines changed

3 files changed

+326
-53
lines changed

llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp

Lines changed: 100 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@ static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
4646

4747
namespace {
4848

49+
enum ThunkArgTranslation : uint8_t {
50+
Direct,
51+
Bitcast,
52+
PointerIndirection,
53+
};
54+
55+
struct ThunkArgInfo {
56+
Type *Arm64Ty;
57+
Type *X64Ty;
58+
ThunkArgTranslation Translation;
59+
};
60+
4961
class AArch64Arm64ECCallLowering : public ModulePass {
5062
public:
5163
static char ID;
@@ -74,25 +86,30 @@ class AArch64Arm64ECCallLowering : public ModulePass {
7486

7587
void getThunkType(FunctionType *FT, AttributeList AttrList,
7688
Arm64ECThunkType TT, raw_ostream &Out,
77-
FunctionType *&Arm64Ty, FunctionType *&X64Ty);
89+
FunctionType *&Arm64Ty, FunctionType *&X64Ty,
90+
SmallVector<ThunkArgTranslation> &ArgTranslations);
7891
void getThunkRetType(FunctionType *FT, AttributeList AttrList,
7992
raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
8093
SmallVectorImpl<Type *> &Arm64ArgTypes,
81-
SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
94+
SmallVectorImpl<Type *> &X64ArgTypes,
95+
SmallVector<ThunkArgTranslation> &ArgTranslations,
96+
bool &HasSretPtr);
8297
void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
8398
Arm64ECThunkType TT, raw_ostream &Out,
8499
SmallVectorImpl<Type *> &Arm64ArgTypes,
85-
SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
86-
void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
87-
uint64_t ArgSizeBytes, raw_ostream &Out,
88-
Type *&Arm64Ty, Type *&X64Ty);
100+
SmallVectorImpl<Type *> &X64ArgTypes,
101+
SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,
102+
bool HasSretPtr);
103+
ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
104+
uint64_t ArgSizeBytes, raw_ostream &Out);
89105
};
90106

91107
} // end anonymous namespace
92108

93109
void AArch64Arm64ECCallLowering::getThunkType(
94110
FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
95-
raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty) {
111+
raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
112+
SmallVector<ThunkArgTranslation> &ArgTranslations) {
96113
Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
97114
: "$iexit_thunk$cdecl$");
98115

@@ -111,10 +128,10 @@ void AArch64Arm64ECCallLowering::getThunkType(
111128

112129
bool HasSretPtr = false;
113130
getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
114-
X64ArgTypes, HasSretPtr);
131+
X64ArgTypes, ArgTranslations, HasSretPtr);
115132

116133
getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
117-
HasSretPtr);
134+
ArgTranslations, HasSretPtr);
118135

119136
Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
120137

@@ -124,7 +141,8 @@ void AArch64Arm64ECCallLowering::getThunkType(
124141
void AArch64Arm64ECCallLowering::getThunkArgTypes(
125142
FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
126143
raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
127-
SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
144+
SmallVectorImpl<Type *> &X64ArgTypes,
145+
SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {
128146

129147
Out << "$";
130148
if (FT->isVarArg()) {
@@ -153,17 +171,20 @@ void AArch64Arm64ECCallLowering::getThunkArgTypes(
153171
for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
154172
Arm64ArgTypes.push_back(I64Ty);
155173
X64ArgTypes.push_back(I64Ty);
174+
ArgTranslations.push_back(ThunkArgTranslation::Direct);
156175
}
157176

158177
// x4
159178
Arm64ArgTypes.push_back(PtrTy);
160179
X64ArgTypes.push_back(PtrTy);
180+
ArgTranslations.push_back(ThunkArgTranslation::Direct);
161181
// x5
162182
Arm64ArgTypes.push_back(I64Ty);
163183
if (TT != Arm64ECThunkType::Entry) {
164184
// FIXME: x5 isn't actually used by the x64 side; revisit once we
165185
// have proper isel for varargs
166186
X64ArgTypes.push_back(I64Ty);
187+
ArgTranslations.push_back(ThunkArgTranslation::Direct);
167188
}
168189
return;
169190
}
@@ -187,18 +208,20 @@ void AArch64Arm64ECCallLowering::getThunkArgTypes(
187208
uint64_t ArgSizeBytes = 0;
188209
Align ParamAlign = Align();
189210
#endif
190-
Type *Arm64Ty, *X64Ty;
191-
canonicalizeThunkType(FT->getParamType(I), ParamAlign,
192-
/*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
211+
auto [Arm64Ty, X64Ty, ArgTranslation] =
212+
canonicalizeThunkType(FT->getParamType(I), ParamAlign,
213+
/*Ret*/ false, ArgSizeBytes, Out);
193214
Arm64ArgTypes.push_back(Arm64Ty);
194215
X64ArgTypes.push_back(X64Ty);
216+
ArgTranslations.push_back(ArgTranslation);
195217
}
196218
}
197219

198220
void AArch64Arm64ECCallLowering::getThunkRetType(
199221
FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
200222
Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
201-
SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
223+
SmallVectorImpl<Type *> &X64ArgTypes,
224+
SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
202225
Type *T = FT->getReturnType();
203226
#if 0
204227
// FIXME: Need more information about argument size; see
@@ -240,13 +263,13 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
240263
// that's a miscompile.)
241264
Type *SRetType = SRetAttr0.getValueAsType();
242265
Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
243-
Type *Arm64Ty, *X64Ty;
244266
canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
245-
Out, Arm64Ty, X64Ty);
267+
Out);
246268
Arm64RetTy = VoidTy;
247269
X64RetTy = VoidTy;
248270
Arm64ArgTypes.push_back(FT->getParamType(0));
249271
X64ArgTypes.push_back(FT->getParamType(0));
272+
ArgTranslations.push_back(ThunkArgTranslation::Direct);
250273
HasSretPtr = true;
251274
return;
252275
}
@@ -258,8 +281,10 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
258281
return;
259282
}
260283

261-
canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
262-
X64RetTy);
284+
auto info =
285+
canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out);
286+
Arm64RetTy = info.Arm64Ty;
287+
X64RetTy = info.X64Ty;
263288
if (X64RetTy->isPointerTy()) {
264289
// If the X64 type is canonicalized to a pointer, that means it's
265290
// passed/returned indirectly. For a return value, that means it's an
@@ -269,21 +294,33 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
269294
}
270295
}
271296

272-
void AArch64Arm64ECCallLowering::canonicalizeThunkType(
273-
Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
274-
Type *&Arm64Ty, Type *&X64Ty) {
297+
ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
298+
Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
299+
raw_ostream &Out) {
300+
301+
auto direct = [](Type *T) {
302+
return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};
303+
};
304+
305+
auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) {
306+
return ThunkArgInfo{Arm64Ty,
307+
llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8),
308+
ThunkArgTranslation::Bitcast};
309+
};
310+
311+
auto pointerIndirection = [this](Type *Arm64Ty) {
312+
return ThunkArgInfo{Arm64Ty, PtrTy,
313+
ThunkArgTranslation::PointerIndirection};
314+
};
315+
275316
if (T->isFloatTy()) {
276317
Out << "f";
277-
Arm64Ty = T;
278-
X64Ty = T;
279-
return;
318+
return direct(T);
280319
}
281320

282321
if (T->isDoubleTy()) {
283322
Out << "d";
284-
Arm64Ty = T;
285-
X64Ty = T;
286-
return;
323+
return direct(T);
287324
}
288325

289326
if (T->isFloatingPointTy()) {
@@ -306,16 +343,14 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
306343
Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
307344
if (Alignment.value() >= 16 && !Ret)
308345
Out << "a" << Alignment.value();
309-
Arm64Ty = T;
310346
if (TotalSizeBytes <= 8) {
311347
// Arm64 returns small structs of float/double in float registers;
312348
// X64 uses RAX.
313-
X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
349+
return bitcast(T, TotalSizeBytes);
314350
} else {
315351
// Struct is passed directly on Arm64, but indirectly on X64.
316-
X64Ty = PtrTy;
352+
return pointerIndirection(T);
317353
}
318-
return;
319354
} else if (T->isFloatingPointTy()) {
320355
report_fatal_error("Only 32 and 64 bit floating points are supported for "
321356
"ARM64EC thunks");
@@ -324,9 +359,7 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
324359

325360
if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
326361
Out << "i8";
327-
Arm64Ty = I64Ty;
328-
X64Ty = I64Ty;
329-
return;
362+
return direct(I64Ty);
330363
}
331364

332365
unsigned TypeSize = ArgSizeBytes;
@@ -338,13 +371,12 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
338371
if (Alignment.value() >= 16 && !Ret)
339372
Out << "a" << Alignment.value();
340373
// FIXME: Try to canonicalize Arm64Ty more thoroughly?
341-
Arm64Ty = T;
342374
if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
343375
// Pass directly in an integer register
344-
X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
376+
return bitcast(T, TypeSize);
345377
} else {
346378
// Passed directly on Arm64, but indirectly on X64.
347-
X64Ty = PtrTy;
379+
return pointerIndirection(T);
348380
}
349381
}
350382

@@ -355,8 +387,9 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
355387
SmallString<256> ExitThunkName;
356388
llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
357389
FunctionType *Arm64Ty, *X64Ty;
390+
SmallVector<ThunkArgTranslation> ArgTranslations;
358391
getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
359-
X64Ty);
392+
X64Ty, ArgTranslations);
360393
if (Function *F = M->getFunction(ExitThunkName))
361394
return F;
362395

@@ -387,6 +420,7 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
387420
SmallVector<Value *> Args;
388421

389422
// Pass the called function in x9.
423+
auto X64TyOffset = 1;
390424
Args.push_back(F->arg_begin());
391425

392426
Type *RetTy = Arm64Ty->getReturnType();
@@ -396,10 +430,14 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
396430
// pointer.
397431
if (DL.getTypeStoreSize(RetTy) > 8) {
398432
Args.push_back(IRB.CreateAlloca(RetTy));
433+
X64TyOffset++;
399434
}
400435
}
401436

402-
for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
437+
for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal(
438+
make_range(F->arg_begin() + 1, F->arg_end()),
439+
make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),
440+
ArgTranslations)) {
403441
// Translate arguments from AArch64 calling convention to x86 calling
404442
// convention.
405443
//
@@ -414,18 +452,20 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
414452
// with an attribute.)
415453
//
416454
// The first argument is the called function, stored in x9.
417-
if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
418-
DL.getTypeStoreSize(Arg.getType()) > 8) {
455+
if (ArgTranslation != ThunkArgTranslation::Direct) {
419456
Value *Mem = IRB.CreateAlloca(Arg.getType());
420457
IRB.CreateStore(&Arg, Mem);
421-
if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
458+
if (ArgTranslation == ThunkArgTranslation::Bitcast) {
422459
Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
423460
Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
424-
} else
461+
} else {
462+
assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
425463
Args.push_back(Mem);
464+
}
426465
} else {
427466
Args.push_back(&Arg);
428467
}
468+
assert(Args.back()->getType() == X64ArgType);
429469
}
430470
// FIXME: Transfer necessary attributes? sret? anything else?
431471

@@ -459,8 +499,10 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
459499
SmallString<256> EntryThunkName;
460500
llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
461501
FunctionType *Arm64Ty, *X64Ty;
502+
SmallVector<ThunkArgTranslation> ArgTranslations;
462503
getThunkType(F->getFunctionType(), F->getAttributes(),
463-
Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty);
504+
Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
505+
ArgTranslations);
464506
if (Function *F = M->getFunction(EntryThunkName))
465507
return F;
466508

@@ -472,7 +514,6 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
472514
// Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
473515
Thunk->addFnAttr("frame-pointer", "all");
474516

475-
auto &DL = M->getDataLayout();
476517
BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
477518
IRBuilder<> IRB(BB);
478519

@@ -481,24 +522,28 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
481522

482523
bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
483524
unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
484-
unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();
525+
unsigned PassthroughArgSize =
526+
(F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
527+
assert(ArgTranslations.size() == F->isVarArg() ? 5 : PassthroughArgSize);
485528

486529
// Translate arguments to call.
487530
SmallVector<Value *> Args;
488-
for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
489-
Value *Arg = Thunk->getArg(i);
490-
Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
491-
if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
492-
DL.getTypeStoreSize(ArgTy) > 8) {
531+
for (unsigned i = 0; i != PassthroughArgSize; ++i) {
532+
Value *Arg = Thunk->getArg(i + ThunkArgOffset);
533+
Type *ArgTy = Arm64Ty->getParamType(i);
534+
ThunkArgTranslation ArgTranslation = ArgTranslations[i];
535+
if (ArgTranslation != ThunkArgTranslation::Direct) {
493536
// Translate array/struct arguments to the expected type.
494-
if (DL.getTypeStoreSize(ArgTy) <= 8) {
537+
if (ArgTranslation == ThunkArgTranslation::Bitcast) {
495538
Value *CastAlloca = IRB.CreateAlloca(ArgTy);
496539
IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
497540
Arg = IRB.CreateLoad(ArgTy, CastAlloca);
498541
} else {
542+
assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
499543
Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
500544
}
501545
}
546+
assert(Arg->getType() == ArgTy);
502547
Args.push_back(Arg);
503548
}
504549

@@ -558,8 +603,10 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
558603
Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
559604
llvm::raw_null_ostream NullThunkName;
560605
FunctionType *Arm64Ty, *X64Ty;
606+
SmallVector<ThunkArgTranslation> ArgTranslations;
561607
getThunkType(F->getFunctionType(), F->getAttributes(),
562-
Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
608+
Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
609+
ArgTranslations);
563610
auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
564611
assert(MangledName && "Can't guest exit to function that's already native");
565612
std::string ThunkName = *MangledName;

0 commit comments

Comments
 (0)