Skip to content

Commit 1fd0c2d

Browse files
committed
Add auto-bitcasts from/to x86amx and i32x256 for AMX intrinsics
1 parent 3ef8e64 commit 1fd0c2d

File tree

12 files changed

+62
-18
lines changed

12 files changed

+62
-18
lines changed

compiler/rustc_codegen_gcc/src/type_of.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::fmt::Write;
22

3-
use gccjit::{Struct, Type};
3+
use gccjit::{RValue, Struct, Type};
44
use rustc_abi as abi;
55
use rustc_abi::Primitive::*;
66
use rustc_abi::{
@@ -373,7 +373,11 @@ impl<'gcc, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
373373
unimplemented!();
374374
}
375375

376-
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Type<'gcc> {
376+
fn fn_decl_backend_type(
377+
&self,
378+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
379+
_fn_ptr: RValue<'gcc>,
380+
) -> Type<'gcc> {
377381
// FIXME(antoyo): Should we do something with `FnAbiGcc::fn_attributes`?
378382
let FnAbiGcc { return_type, arguments_type, is_c_variadic, .. } = fn_abi.gcc_type(self);
379383
self.context.new_function_pointer_type(None, return_type, &arguments_type, is_c_variadic)

compiler/rustc_codegen_llvm/src/abi.rs

+25-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::cmp;
44
use libc::c_uint;
55
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
66
use rustc_codegen_ssa::MemFlags;
7+
use rustc_codegen_ssa::common::TypeKind;
78
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
89
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
910
use rustc_codegen_ssa::traits::*;
@@ -308,7 +309,7 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
308309
}
309310

310311
pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
311-
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
312+
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>, name: &[u8]) -> &'ll Type;
312313
fn ptr_to_llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
313314
fn llvm_cconv(&self, cx: &CodegenCx<'ll, 'tcx>) -> llvm::CallConv;
314315

@@ -325,26 +326,45 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
325326
}
326327

327328
impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
328-
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
329+
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>, name: &[u8]) -> &'ll Type {
329330
// Ignore "extra" args from the call site for C variadic functions.
330331
// Only the "fixed" args are part of the LLVM function signature.
331332
let args =
332333
if self.c_variadic { &self.args[..self.fixed_count as usize] } else { &self.args };
333334

335+
// todo(sayantn): a better way is to look at the `link_name` instead of the function name, because function name can be "faked" using `#[export_name]`
336+
let llvm_intrinsic = name.starts_with(b"llvm.")
337+
&& !self.c_variadic
338+
&& self.conv == Conv::C
339+
&& !self.can_unwind;
340+
let amx_intrinsic =
341+
llvm_intrinsic && name.starts_with(b"llvm.x86.") && name.ends_with(b".internal");
342+
let adjust_ty = |ty| {
343+
// Change type to `x86amx` from `i32x256` for x86_64 AMX intrinsics
344+
if amx_intrinsic && cx.type_kind(ty) == TypeKind::Vector && cx.vector_length(ty) == 256
345+
{
346+
let element_ty = cx.element_type(ty);
347+
if cx.type_kind(element_ty) == TypeKind::Integer && cx.int_width(element_ty) == 32 {
348+
return cx.type_x86amx();
349+
}
350+
}
351+
ty
352+
};
353+
334354
// This capacity calculation is approximate.
335355
let mut llargument_tys = Vec::with_capacity(
336356
self.args.len() + if let PassMode::Indirect { .. } = self.ret.mode { 1 } else { 0 },
337357
);
338358

339-
let llreturn_ty = match &self.ret.mode {
359+
let llreturn_ty = adjust_ty(match &self.ret.mode {
340360
PassMode::Ignore => cx.type_void(),
341361
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
342362
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
343363
PassMode::Indirect { .. } => {
344364
llargument_tys.push(cx.type_ptr());
345365
cx.type_void()
346366
}
347-
};
367+
});
348368

349369
for arg in args {
350370
// Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +408,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388408
cast.llvm_type(cx)
389409
}
390410
};
391-
llargument_tys.push(llarg_ty);
411+
llargument_tys.push(adjust_ty(llarg_ty));
392412
}
393413

394414
if self.c_variadic {

compiler/rustc_codegen_llvm/src/builder.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
14351435
if let Some(fn_abi) = fn_abi {
14361436
fn_abi.apply_attrs_callsite(self, call);
14371437
}
1438-
call
1438+
1439+
if self.cx.type_kind(self.cx.val_ty(call)) == TypeKind::X86_AMX {
1440+
self.bitcast(call, self.cx.type_vector(self.cx.type_i32(), 256))
1441+
} else {
1442+
call
1443+
}
14391444
}
14401445

14411446
fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {

compiler/rustc_codegen_llvm/src/declare.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
158158
fn_abi.llvm_cconv(self),
159159
llvm::UnnamedAddr::Global,
160160
llvm::Visibility::Default,
161-
fn_abi.llvm_type(self),
161+
fn_abi.llvm_type(self, name.as_ref()),
162162
);
163163
fn_abi.apply_attrs_llfn(self, llfn, instance);
164164

compiler/rustc_codegen_llvm/src/intrinsic.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ fn gen_fn<'a, 'll, 'tcx>(
10901090
codegen: &mut dyn FnMut(Builder<'a, 'll, 'tcx>),
10911091
) -> (&'ll Type, &'ll Value) {
10921092
let fn_abi = cx.fn_abi_of_fn_ptr(rust_fn_sig, ty::List::empty());
1093-
let llty = fn_abi.llvm_type(cx);
1093+
let llty = fn_abi.llvm_type(cx, name.as_ref());
10941094
let llfn = cx.declare_fn(name, fn_abi, None);
10951095
cx.set_frame_pointer_type(llfn);
10961096
cx.apply_target_cpu_attr(llfn);

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,9 @@ unsafe extern "C" {
10551055
pub(crate) fn LLVMPointerTypeInContext(C: &Context, AddressSpace: c_uint) -> &Type;
10561056
pub(crate) fn LLVMVectorType(ElementType: &Type, ElementCount: c_uint) -> &Type;
10571057

1058+
// Special X86 Type for AMX
1059+
pub(crate) fn LLVMX86AMXTypeInContext(C: &Context) -> &Type;
1060+
10581061
pub(crate) fn LLVMGetElementType(Ty: &Type) -> &Type;
10591062
pub(crate) fn LLVMGetVectorSize(VectorTy: &Type) -> c_uint;
10601063

compiler/rustc_codegen_llvm/src/type_.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
154154
)
155155
}
156156
}
157+
158+
pub(crate) fn type_x86amx(&self) -> &'ll Type {
159+
unsafe { llvm::LLVMX86AMXTypeInContext(self.llcx()) }
160+
}
157161
}
158162

159163
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -284,8 +288,12 @@ impl<'ll, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
284288
fn cast_backend_type(&self, ty: &CastTarget) -> &'ll Type {
285289
ty.llvm_type(self)
286290
}
287-
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> &'ll Type {
288-
fn_abi.llvm_type(self)
291+
fn fn_decl_backend_type(
292+
&self,
293+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
294+
fn_ptr: &'ll Value,
295+
) -> &'ll Type {
296+
fn_abi.llvm_type(self, llvm::get_value_name(fn_ptr))
289297
}
290298
fn fn_ptr_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> &'ll Type {
291299
fn_abi.ptr_to_llvm_type(self)

compiler/rustc_codegen_ssa/src/mir/block.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> {
187187

188188
// If there is a cleanup block and the function we're calling can unwind, then
189189
// do an invoke, otherwise do a call.
190-
let fn_ty = bx.fn_decl_backend_type(fn_abi);
190+
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);
191191

192192
let fn_attrs = if bx.tcx().def_kind(fx.instance.def_id()).has_codegen_attrs() {
193193
Some(bx.tcx().codegen_fn_attrs(fx.instance.def_id()))
@@ -1806,7 +1806,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
18061806
if is_call_from_compiler_builtins_to_upstream_monomorphization(bx.tcx(), instance) {
18071807
bx.abort();
18081808
} else {
1809-
let fn_ty = bx.fn_decl_backend_type(fn_abi);
1809+
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);
18101810

18111811
let llret = bx.call(fn_ty, None, Some(fn_abi), fn_ptr, &[], funclet.as_ref(), None);
18121812
bx.apply_attrs_to_cleanup_callsite(llret);

compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
779779
};
780780
let fn_ptr = bx.get_fn_addr(instance);
781781
let fn_abi = bx.fn_abi_of_instance(instance, ty::List::empty());
782-
let fn_ty = bx.fn_decl_backend_type(fn_abi);
782+
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);
783783
let fn_attrs = if bx.tcx().def_kind(instance.def_id()).has_codegen_attrs() {
784784
Some(bx.tcx().codegen_fn_attrs(instance.def_id()))
785785
} else {

compiler/rustc_codegen_ssa/src/size_of_val.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ pub fn size_and_align_of_dst<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
6767
// Generate the call. Cannot use `do_call` since we don't have a MIR terminator so we
6868
// can't create a `TerminationCodegenHelper`. (But we are in good company, this code is
6969
// duplicated plenty of times.)
70-
let fn_ty = bx.fn_decl_backend_type(fn_abi);
70+
let fn_ty = bx.fn_decl_backend_type(fn_abi, llfn);
7171

7272
bx.call(
7373
fn_ty,

compiler/rustc_codegen_ssa/src/traits/type_.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ pub trait LayoutTypeCodegenMethods<'tcx>: BackendTypes {
9696
/// such as when it's stack-allocated or when it's being loaded or stored.
9797
fn backend_type(&self, layout: TyAndLayout<'tcx>) -> Self::Type;
9898
fn cast_backend_type(&self, ty: &CastTarget) -> Self::Type;
99-
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Self::Type;
99+
fn fn_decl_backend_type(
100+
&self,
101+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
102+
fn_ptr: Self::Value,
103+
) -> Self::Type;
100104
fn fn_ptr_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Self::Type;
101105
fn reg_backend_type(&self, ty: &Reg) -> Self::Type;
102106
/// The backend type used for a rust type when it's in an SSA register.

compiler/rustc_target/src/target_features.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ pub fn all_rust_features() -> impl Iterator<Item = (&'static str, Stability)> {
786786
// certain size to have their "proper" ABI on each architecture.
787787
// Note that they must be kept sorted by vector size.
788788
const X86_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] =
789-
&[(128, "sse"), (256, "avx"), (512, "avx512f")]; // FIXME: might need changes for AVX10.
789+
&[(128, "sse"), (256, "avx"), (512, "avx512f"), (8192, "amx-tile")];
790790
const AARCH64_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] = &[(128, "neon")];
791791

792792
// We might want to add "helium" too.

0 commit comments

Comments
 (0)