@@ -46,6 +46,18 @@ static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
46
46
47
47
namespace {
48
48
49
+ enum ThunkArgTranslation : uint8_t {
50
+ Direct,
51
+ Bitcast,
52
+ PointerIndirection,
53
+ };
54
+
55
+ struct ThunkArgInfo {
56
+ Type *Arm64Ty;
57
+ Type *X64Ty;
58
+ ThunkArgTranslation Translation;
59
+ };
60
+
49
61
class AArch64Arm64ECCallLowering : public ModulePass {
50
62
public:
51
63
static char ID;
@@ -74,25 +86,30 @@ class AArch64Arm64ECCallLowering : public ModulePass {
74
86
75
87
void getThunkType (FunctionType *FT, AttributeList AttrList,
76
88
Arm64ECThunkType TT, raw_ostream &Out,
77
- FunctionType *&Arm64Ty, FunctionType *&X64Ty);
89
+ FunctionType *&Arm64Ty, FunctionType *&X64Ty,
90
+ SmallVector<ThunkArgTranslation> &ArgTranslations);
78
91
void getThunkRetType (FunctionType *FT, AttributeList AttrList,
79
92
raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
80
93
SmallVectorImpl<Type *> &Arm64ArgTypes,
81
- SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
94
+ SmallVectorImpl<Type *> &X64ArgTypes,
95
+ SmallVector<ThunkArgTranslation> &ArgTranslations,
96
+ bool &HasSretPtr);
82
97
void getThunkArgTypes (FunctionType *FT, AttributeList AttrList,
83
98
Arm64ECThunkType TT, raw_ostream &Out,
84
99
SmallVectorImpl<Type *> &Arm64ArgTypes,
85
- SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
86
- void canonicalizeThunkType (Type *T, Align Alignment, bool Ret,
87
- uint64_t ArgSizeBytes, raw_ostream &Out,
88
- Type *&Arm64Ty, Type *&X64Ty);
100
+ SmallVectorImpl<Type *> &X64ArgTypes,
101
+ SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,
102
+ bool HasSretPtr);
103
+ ThunkArgInfo canonicalizeThunkType (Type *T, Align Alignment, bool Ret,
104
+ uint64_t ArgSizeBytes, raw_ostream &Out);
89
105
};
90
106
91
107
} // end anonymous namespace
92
108
93
109
void AArch64Arm64ECCallLowering::getThunkType (
94
110
FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
95
- raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty) {
111
+ raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
112
+ SmallVector<ThunkArgTranslation> &ArgTranslations) {
96
113
Out << (TT == Arm64ECThunkType::Entry ? " $ientry_thunk$cdecl$"
97
114
: " $iexit_thunk$cdecl$" );
98
115
@@ -111,10 +128,10 @@ void AArch64Arm64ECCallLowering::getThunkType(
111
128
112
129
bool HasSretPtr = false ;
113
130
getThunkRetType (FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
114
- X64ArgTypes, HasSretPtr);
131
+ X64ArgTypes, ArgTranslations, HasSretPtr);
115
132
116
133
getThunkArgTypes (FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
117
- HasSretPtr);
134
+ ArgTranslations, HasSretPtr);
118
135
119
136
Arm64Ty = FunctionType::get (Arm64RetTy, Arm64ArgTypes, false );
120
137
@@ -124,7 +141,8 @@ void AArch64Arm64ECCallLowering::getThunkType(
124
141
void AArch64Arm64ECCallLowering::getThunkArgTypes (
125
142
FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
126
143
raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
127
- SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
144
+ SmallVectorImpl<Type *> &X64ArgTypes,
145
+ SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {
128
146
129
147
Out << " $" ;
130
148
if (FT->isVarArg ()) {
@@ -153,17 +171,20 @@ void AArch64Arm64ECCallLowering::getThunkArgTypes(
153
171
for (int i = HasSretPtr ? 1 : 0 ; i < 4 ; i++) {
154
172
Arm64ArgTypes.push_back (I64Ty);
155
173
X64ArgTypes.push_back (I64Ty);
174
+ ArgTranslations.push_back (ThunkArgTranslation::Direct);
156
175
}
157
176
158
177
// x4
159
178
Arm64ArgTypes.push_back (PtrTy);
160
179
X64ArgTypes.push_back (PtrTy);
180
+ ArgTranslations.push_back (ThunkArgTranslation::Direct);
161
181
// x5
162
182
Arm64ArgTypes.push_back (I64Ty);
163
183
if (TT != Arm64ECThunkType::Entry) {
164
184
// FIXME: x5 isn't actually used by the x64 side; revisit once we
165
185
// have proper isel for varargs
166
186
X64ArgTypes.push_back (I64Ty);
187
+ ArgTranslations.push_back (ThunkArgTranslation::Direct);
167
188
}
168
189
return ;
169
190
}
@@ -187,18 +208,20 @@ void AArch64Arm64ECCallLowering::getThunkArgTypes(
187
208
uint64_t ArgSizeBytes = 0 ;
188
209
Align ParamAlign = Align ();
189
210
#endif
190
- Type * Arm64Ty, * X64Ty;
191
- canonicalizeThunkType (FT->getParamType (I), ParamAlign,
192
- /* Ret*/ false , ArgSizeBytes, Out, Arm64Ty, X64Ty );
211
+ auto [ Arm64Ty, X64Ty, ArgTranslation] =
212
+ canonicalizeThunkType (FT->getParamType (I), ParamAlign,
213
+ /* Ret*/ false , ArgSizeBytes, Out);
193
214
Arm64ArgTypes.push_back (Arm64Ty);
194
215
X64ArgTypes.push_back (X64Ty);
216
+ ArgTranslations.push_back (ArgTranslation);
195
217
}
196
218
}
197
219
198
220
void AArch64Arm64ECCallLowering::getThunkRetType (
199
221
FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
200
222
Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
201
- SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
223
+ SmallVectorImpl<Type *> &X64ArgTypes,
224
+ SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
202
225
Type *T = FT->getReturnType ();
203
226
#if 0
204
227
// FIXME: Need more information about argument size; see
@@ -240,13 +263,13 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
240
263
// that's a miscompile.)
241
264
Type *SRetType = SRetAttr0.getValueAsType ();
242
265
Align SRetAlign = AttrList.getParamAlignment (0 ).valueOrOne ();
243
- Type *Arm64Ty, *X64Ty;
244
266
canonicalizeThunkType (SRetType, SRetAlign, /* Ret*/ true , ArgSizeBytes,
245
- Out, Arm64Ty, X64Ty );
267
+ Out);
246
268
Arm64RetTy = VoidTy;
247
269
X64RetTy = VoidTy;
248
270
Arm64ArgTypes.push_back (FT->getParamType (0 ));
249
271
X64ArgTypes.push_back (FT->getParamType (0 ));
272
+ ArgTranslations.push_back (ThunkArgTranslation::Direct);
250
273
HasSretPtr = true ;
251
274
return ;
252
275
}
@@ -258,8 +281,10 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
258
281
return ;
259
282
}
260
283
261
- canonicalizeThunkType (T, Align (), /* Ret*/ true , ArgSizeBytes, Out, Arm64RetTy,
262
- X64RetTy);
284
+ auto info =
285
+ canonicalizeThunkType (T, Align (), /* Ret*/ true , ArgSizeBytes, Out);
286
+ Arm64RetTy = info.Arm64Ty ;
287
+ X64RetTy = info.X64Ty ;
263
288
if (X64RetTy->isPointerTy ()) {
264
289
// If the X64 type is canonicalized to a pointer, that means it's
265
290
// passed/returned indirectly. For a return value, that means it's an
@@ -269,21 +294,33 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
269
294
}
270
295
}
271
296
272
- void AArch64Arm64ECCallLowering::canonicalizeThunkType (
273
- Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
274
- Type *&Arm64Ty, Type *&X64Ty) {
297
+ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType (
298
+ Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
299
+ raw_ostream &Out) {
300
+
301
+ auto direct = [](Type *T) {
302
+ return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};
303
+ };
304
+
305
+ auto bitcast = [this ](Type *Arm64Ty, uint64_t SizeInBytes) {
306
+ return ThunkArgInfo{Arm64Ty,
307
+ llvm::Type::getIntNTy (M->getContext (), SizeInBytes * 8 ),
308
+ ThunkArgTranslation::Bitcast};
309
+ };
310
+
311
+ auto pointerIndirection = [this ](Type *Arm64Ty) {
312
+ return ThunkArgInfo{Arm64Ty, PtrTy,
313
+ ThunkArgTranslation::PointerIndirection};
314
+ };
315
+
275
316
if (T->isFloatTy ()) {
276
317
Out << " f" ;
277
- Arm64Ty = T;
278
- X64Ty = T;
279
- return ;
318
+ return direct (T);
280
319
}
281
320
282
321
if (T->isDoubleTy ()) {
283
322
Out << " d" ;
284
- Arm64Ty = T;
285
- X64Ty = T;
286
- return ;
323
+ return direct (T);
287
324
}
288
325
289
326
if (T->isFloatingPointTy ()) {
@@ -306,16 +343,14 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
306
343
Out << (ElementTy->isFloatTy () ? " F" : " D" ) << TotalSizeBytes;
307
344
if (Alignment.value () >= 16 && !Ret)
308
345
Out << " a" << Alignment.value ();
309
- Arm64Ty = T;
310
346
if (TotalSizeBytes <= 8 ) {
311
347
// Arm64 returns small structs of float/double in float registers;
312
348
// X64 uses RAX.
313
- X64Ty = llvm::Type::getIntNTy (M-> getContext () , TotalSizeBytes * 8 );
349
+ return bitcast (T , TotalSizeBytes);
314
350
} else {
315
351
// Struct is passed directly on Arm64, but indirectly on X64.
316
- X64Ty = PtrTy ;
352
+ return pointerIndirection (T) ;
317
353
}
318
- return ;
319
354
} else if (T->isFloatingPointTy ()) {
320
355
report_fatal_error (" Only 32 and 64 bit floating points are supported for "
321
356
" ARM64EC thunks" );
@@ -324,9 +359,7 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
324
359
325
360
if ((T->isIntegerTy () || T->isPointerTy ()) && DL.getTypeSizeInBits (T) <= 64 ) {
326
361
Out << " i8" ;
327
- Arm64Ty = I64Ty;
328
- X64Ty = I64Ty;
329
- return ;
362
+ return direct (I64Ty);
330
363
}
331
364
332
365
unsigned TypeSize = ArgSizeBytes;
@@ -338,13 +371,12 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
338
371
if (Alignment.value () >= 16 && !Ret)
339
372
Out << " a" << Alignment.value ();
340
373
// FIXME: Try to canonicalize Arm64Ty more thoroughly?
341
- Arm64Ty = T;
342
374
if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8 ) {
343
375
// Pass directly in an integer register
344
- X64Ty = llvm::Type::getIntNTy (M-> getContext () , TypeSize * 8 );
376
+ return bitcast (T , TypeSize);
345
377
} else {
346
378
// Passed directly on Arm64, but indirectly on X64.
347
- X64Ty = PtrTy ;
379
+ return pointerIndirection (T) ;
348
380
}
349
381
}
350
382
@@ -355,8 +387,9 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
355
387
SmallString<256 > ExitThunkName;
356
388
llvm::raw_svector_ostream ExitThunkStream (ExitThunkName);
357
389
FunctionType *Arm64Ty, *X64Ty;
390
+ SmallVector<ThunkArgTranslation> ArgTranslations;
358
391
getThunkType (FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
359
- X64Ty);
392
+ X64Ty, ArgTranslations );
360
393
if (Function *F = M->getFunction (ExitThunkName))
361
394
return F;
362
395
@@ -387,6 +420,7 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
387
420
SmallVector<Value *> Args;
388
421
389
422
// Pass the called function in x9.
423
+ auto X64TyOffset = 1 ;
390
424
Args.push_back (F->arg_begin ());
391
425
392
426
Type *RetTy = Arm64Ty->getReturnType ();
@@ -396,10 +430,14 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
396
430
// pointer.
397
431
if (DL.getTypeStoreSize (RetTy) > 8 ) {
398
432
Args.push_back (IRB.CreateAlloca (RetTy));
433
+ X64TyOffset++;
399
434
}
400
435
}
401
436
402
- for (auto &Arg : make_range (F->arg_begin () + 1 , F->arg_end ())) {
437
+ for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal (
438
+ make_range (F->arg_begin () + 1 , F->arg_end ()),
439
+ make_range (X64Ty->param_begin () + X64TyOffset, X64Ty->param_end ()),
440
+ ArgTranslations)) {
403
441
// Translate arguments from AArch64 calling convention to x86 calling
404
442
// convention.
405
443
//
@@ -414,18 +452,20 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
414
452
// with an attribute.)
415
453
//
416
454
// The first argument is the called function, stored in x9.
417
- if (Arg.getType ()->isArrayTy () || Arg.getType ()->isStructTy () ||
418
- DL.getTypeStoreSize (Arg.getType ()) > 8 ) {
455
+ if (ArgTranslation != ThunkArgTranslation::Direct) {
419
456
Value *Mem = IRB.CreateAlloca (Arg.getType ());
420
457
IRB.CreateStore (&Arg, Mem);
421
- if (DL. getTypeStoreSize (Arg. getType ()) <= 8 ) {
458
+ if (ArgTranslation == ThunkArgTranslation::Bitcast ) {
422
459
Type *IntTy = IRB.getIntNTy (DL.getTypeStoreSizeInBits (Arg.getType ()));
423
460
Args.push_back (IRB.CreateLoad (IntTy, IRB.CreateBitCast (Mem, PtrTy)));
424
- } else
461
+ } else {
462
+ assert (ArgTranslation == ThunkArgTranslation::PointerIndirection);
425
463
Args.push_back (Mem);
464
+ }
426
465
} else {
427
466
Args.push_back (&Arg);
428
467
}
468
+ assert (Args.back ()->getType () == X64ArgType);
429
469
}
430
470
// FIXME: Transfer necessary attributes? sret? anything else?
431
471
@@ -459,8 +499,10 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
459
499
SmallString<256 > EntryThunkName;
460
500
llvm::raw_svector_ostream EntryThunkStream (EntryThunkName);
461
501
FunctionType *Arm64Ty, *X64Ty;
502
+ SmallVector<ThunkArgTranslation> ArgTranslations;
462
503
getThunkType (F->getFunctionType (), F->getAttributes (),
463
- Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty);
504
+ Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
505
+ ArgTranslations);
464
506
if (Function *F = M->getFunction (EntryThunkName))
465
507
return F;
466
508
@@ -472,7 +514,6 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
472
514
// Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
473
515
Thunk->addFnAttr (" frame-pointer" , " all" );
474
516
475
- auto &DL = M->getDataLayout ();
476
517
BasicBlock *BB = BasicBlock::Create (M->getContext (), " " , Thunk);
477
518
IRBuilder<> IRB (BB);
478
519
@@ -481,24 +522,28 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
481
522
482
523
bool TransformDirectToSRet = X64RetType->isVoidTy () && !RetTy->isVoidTy ();
483
524
unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1 ;
484
- unsigned PassthroughArgSize = F->isVarArg () ? 5 : Thunk->arg_size ();
525
+ unsigned PassthroughArgSize =
526
+ (F->isVarArg () ? 5 : Thunk->arg_size ()) - ThunkArgOffset;
527
+ assert (ArgTranslations.size () == F->isVarArg () ? 5 : PassthroughArgSize);
485
528
486
529
// Translate arguments to call.
487
530
SmallVector<Value *> Args;
488
- for (unsigned i = ThunkArgOffset, e = PassthroughArgSize ; i != e ; ++i) {
489
- Value *Arg = Thunk->getArg (i);
490
- Type *ArgTy = Arm64Ty->getParamType (i - ThunkArgOffset );
491
- if (ArgTy-> isArrayTy () || ArgTy-> isStructTy () ||
492
- DL. getTypeStoreSize (ArgTy) > 8 ) {
531
+ for (unsigned i = 0 ; i != PassthroughArgSize ; ++i) {
532
+ Value *Arg = Thunk->getArg (i + ThunkArgOffset );
533
+ Type *ArgTy = Arm64Ty->getParamType (i);
534
+ ThunkArgTranslation ArgTranslation = ArgTranslations[i];
535
+ if (ArgTranslation != ThunkArgTranslation::Direct ) {
493
536
// Translate array/struct arguments to the expected type.
494
- if (DL. getTypeStoreSize (ArgTy) <= 8 ) {
537
+ if (ArgTranslation == ThunkArgTranslation::Bitcast ) {
495
538
Value *CastAlloca = IRB.CreateAlloca (ArgTy);
496
539
IRB.CreateStore (Arg, IRB.CreateBitCast (CastAlloca, PtrTy));
497
540
Arg = IRB.CreateLoad (ArgTy, CastAlloca);
498
541
} else {
542
+ assert (ArgTranslation == ThunkArgTranslation::PointerIndirection);
499
543
Arg = IRB.CreateLoad (ArgTy, IRB.CreateBitCast (Arg, PtrTy));
500
544
}
501
545
}
546
+ assert (Arg->getType () == ArgTy);
502
547
Args.push_back (Arg);
503
548
}
504
549
@@ -558,8 +603,10 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
558
603
Function *AArch64Arm64ECCallLowering::buildGuestExitThunk (Function *F) {
559
604
llvm::raw_null_ostream NullThunkName;
560
605
FunctionType *Arm64Ty, *X64Ty;
606
+ SmallVector<ThunkArgTranslation> ArgTranslations;
561
607
getThunkType (F->getFunctionType (), F->getAttributes (),
562
- Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
608
+ Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
609
+ ArgTranslations);
563
610
auto MangledName = getArm64ECMangledFunctionName (F->getName ().str ());
564
611
assert (MangledName && " Can't guest exit to function that's already native" );
565
612
std::string ThunkName = *MangledName;
0 commit comments