@@ -4,6 +4,7 @@ use std::cmp;
4
4
use libc:: c_uint;
5
5
use rustc_abi:: { BackendRepr , HasDataLayout , Primitive , Reg , RegKind , Size } ;
6
6
use rustc_codegen_ssa:: MemFlags ;
7
+ use rustc_codegen_ssa:: common:: TypeKind ;
7
8
use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
8
9
use rustc_codegen_ssa:: mir:: place:: { PlaceRef , PlaceValue } ;
9
10
use rustc_codegen_ssa:: traits:: * ;
@@ -308,7 +309,7 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
308
309
}
309
310
310
311
pub ( crate ) trait FnAbiLlvmExt < ' ll , ' tcx > {
311
- fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type ;
312
+ fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > , name : & [ u8 ] ) -> & ' ll Type ;
312
313
fn ptr_to_llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type ;
313
314
fn llvm_cconv ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> llvm:: CallConv ;
314
315
@@ -325,26 +326,45 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
325
326
}
326
327
327
328
impl < ' ll , ' tcx > FnAbiLlvmExt < ' ll , ' tcx > for FnAbi < ' tcx , Ty < ' tcx > > {
328
- fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type {
329
+ fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > , name : & [ u8 ] ) -> & ' ll Type {
329
330
// Ignore "extra" args from the call site for C variadic functions.
330
331
// Only the "fixed" args are part of the LLVM function signature.
331
332
let args =
332
333
if self . c_variadic { & self . args [ ..self . fixed_count as usize ] } else { & self . args } ;
333
334
335
+ // todo(sayantn): a better way is to look at the `link_name` instead of the function name, because function name can be "faked" using `#[export_name]`
336
+ let llvm_intrinsic = name. starts_with ( b"llvm." )
337
+ && !self . c_variadic
338
+ && self . conv == Conv :: C
339
+ && !self . can_unwind ;
340
+ let amx_intrinsic =
341
+ llvm_intrinsic && name. starts_with ( b"llvm.x86." ) && name. ends_with ( b".internal" ) ;
342
+ let adjust_ty = |ty| {
343
+ // Change type to `x86amx` from `i32x256` for x86_64 AMX intrinsics
344
+ if amx_intrinsic && cx. type_kind ( ty) == TypeKind :: Vector && cx. vector_length ( ty) == 256
345
+ {
346
+ let element_ty = cx. element_type ( ty) ;
347
+ if cx. type_kind ( element_ty) == TypeKind :: Integer && cx. int_width ( element_ty) == 32 {
348
+ return cx. type_x86amx ( ) ;
349
+ }
350
+ }
351
+ ty
352
+ } ;
353
+
334
354
// This capacity calculation is approximate.
335
355
let mut llargument_tys = Vec :: with_capacity (
336
356
self . args . len ( ) + if let PassMode :: Indirect { .. } = self . ret . mode { 1 } else { 0 } ,
337
357
) ;
338
358
339
- let llreturn_ty = match & self . ret . mode {
359
+ let llreturn_ty = adjust_ty ( match & self . ret . mode {
340
360
PassMode :: Ignore => cx. type_void ( ) ,
341
361
PassMode :: Direct ( _) | PassMode :: Pair ( ..) => self . ret . layout . immediate_llvm_type ( cx) ,
342
362
PassMode :: Cast { cast, pad_i32 : _ } => cast. llvm_type ( cx) ,
343
363
PassMode :: Indirect { .. } => {
344
364
llargument_tys. push ( cx. type_ptr ( ) ) ;
345
365
cx. type_void ( )
346
366
}
347
- } ;
367
+ } ) ;
348
368
349
369
for arg in args {
350
370
// Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +408,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388
408
cast. llvm_type ( cx)
389
409
}
390
410
} ;
391
- llargument_tys. push ( llarg_ty) ;
411
+ llargument_tys. push ( adjust_ty ( llarg_ty) ) ;
392
412
}
393
413
394
414
if self . c_variadic {
0 commit comments