Add auto-bitcasts between x86amx and i32x256 for AMX intrinsics #140763

Open · wants to merge 1 commit into base: master
compiler/rustc_codegen_gcc/src/type_of.rs (8 changes: 6 additions & 2 deletions)
@@ -1,6 +1,6 @@
use std::fmt::Write;

use gccjit::{Struct, Type};
use gccjit::{RValue, Struct, Type};
use rustc_abi as abi;
use rustc_abi::Primitive::*;
use rustc_abi::{
@@ -373,7 +373,11 @@ impl<'gcc, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
unimplemented!();
}

fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Type<'gcc> {
fn fn_decl_backend_type(
&self,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
_fn_ptr: RValue<'gcc>,
) -> Type<'gcc> {
// FIXME(antoyo): Should we do something with `FnAbiGcc::fn_attributes`?
let FnAbiGcc { return_type, arguments_type, is_c_variadic, .. } = fn_abi.gcc_type(self);
self.context.new_function_pointer_type(None, return_type, &arguments_type, is_c_variadic)
compiler/rustc_codegen_llvm/src/abi.rs (35 changes: 30 additions & 5 deletions)
@@ -4,6 +4,7 @@ use std::cmp;
use libc::c_uint;
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
use rustc_codegen_ssa::MemFlags;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
use rustc_codegen_ssa::traits::*;
@@ -308,7 +309,12 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
}

pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_type(
&self,
cx: &CodegenCx<'ll, 'tcx>,
name: &[u8],
is_llvm_intrinsic: bool,
) -> &'ll Type;
fn ptr_to_llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_cconv(&self, cx: &CodegenCx<'ll, 'tcx>) -> llvm::CallConv;

@@ -325,26 +331,45 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
}

impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
fn llvm_type(
&self,
cx: &CodegenCx<'ll, 'tcx>,
name: &[u8],
is_llvm_intrinsic: bool,
) -> &'ll Type {
// Ignore "extra" args from the call site for C variadic functions.
// Only the "fixed" args are part of the LLVM function signature.
let args =
if self.c_variadic { &self.args[..self.fixed_count as usize] } else { &self.args };

let amx_intrinsic =
is_llvm_intrinsic && name.starts_with(b"llvm.x86.") && name.ends_with(b".internal");
let adjust_ty = |ty| {
// Change type to `x86amx` from `i32x256` for x86_64 AMX intrinsics
if amx_intrinsic && cx.type_kind(ty) == TypeKind::Vector && cx.vector_length(ty) == 256
{
let element_ty = cx.element_type(ty);
if cx.type_kind(element_ty) == TypeKind::Integer && cx.int_width(element_ty) == 32 {
return cx.type_x86amx();
}
}
ty
};

// This capacity calculation is approximate.
let mut llargument_tys = Vec::with_capacity(
self.args.len() + if let PassMode::Indirect { .. } = self.ret.mode { 1 } else { 0 },
);

let llreturn_ty = match &self.ret.mode {
let llreturn_ty = adjust_ty(match &self.ret.mode {
PassMode::Ignore => cx.type_void(),
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
PassMode::Indirect { .. } => {
llargument_tys.push(cx.type_ptr());
cx.type_void()
}
};
});

for arg in args {
// Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +413,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
cast.llvm_type(cx)
}
};
llargument_tys.push(llarg_ty);
llargument_tys.push(adjust_ty(llarg_ty));
}

if self.c_variadic {
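The core of the change is in `FnAbi::llvm_type`: the new `adjust_ty` closure rewrites `<256 x i32>` to the opaque `x86amx` type, but only when the callee is an LLVM intrinsic whose name matches `llvm.x86.*.internal`. A minimal standalone sketch of that decision, using a stand-in enum rather than the backend's real `&'ll Type` values (names here are illustrative, not compiler API):

// Standalone sketch of the signature adjustment in FnAbi::llvm_type.
#[derive(Debug, PartialEq, Clone, Copy)]
enum SketchTy {
    I32Vector { lanes: u32 },
    X86Amx,
    Ptr,
}

fn is_amx_intrinsic(name: &str, is_llvm_intrinsic: bool) -> bool {
    is_llvm_intrinsic && name.starts_with("llvm.x86.") && name.ends_with(".internal")
}

fn adjust_ty(ty: SketchTy, amx_intrinsic: bool) -> SketchTy {
    match ty {
        // <256 x i32> becomes x86amx, but only inside an AMX intrinsic signature.
        SketchTy::I32Vector { lanes: 256 } if amx_intrinsic => SketchTy::X86Amx,
        other => other,
    }
}

fn main() {
    let amx = is_amx_intrinsic("llvm.x86.tileloadd64.internal", true);
    assert!(amx);
    assert_eq!(adjust_ty(SketchTy::I32Vector { lanes: 256 }, amx), SketchTy::X86Amx);
    // Ordinary (non-intrinsic) functions keep their vector and pointer types untouched.
    assert!(!is_amx_intrinsic("tileloadd64", false));
    assert_eq!(adjust_ty(SketchTy::Ptr, amx), SketchTy::Ptr);
}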
compiler/rustc_codegen_llvm/src/builder.rs (7 changes: 6 additions & 1 deletion)
@@ -1435,7 +1435,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
if let Some(fn_abi) = fn_abi {
fn_abi.apply_attrs_callsite(self, call);
}
call

if self.cx.type_kind(self.cx.val_ty(call)) == TypeKind::X86_AMX {
self.bitcast(call, self.cx.type_vector(self.cx.type_i32(), 256))
} else {
call
}
}

fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
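`Builder::call` applies the inverse mapping: if the emitted call produces an `x86amx` value, it is immediately bitcast back to `<256 x i32>`, so the surrounding Rust-typed code only ever sees the vector representation. A short sketch of that rule with the same stand-in types (not the real builder API):

// Sketch of the result fix-up in Builder::call; the real code checks
// type_kind(val_ty(call)) == TypeKind::X86_AMX and emits an LLVM bitcast.
#[derive(Debug, PartialEq, Clone, Copy)]
enum SketchTy {
    I32Vector { lanes: u32 },
    X86Amx,
}

fn normalize_call_result(ret_ty: SketchTy) -> SketchTy {
    match ret_ty {
        // bitcast x86amx -> <256 x i32> right after the call instruction
        SketchTy::X86Amx => SketchTy::I32Vector { lanes: 256 },
        other => other,
    }
}

fn main() {
    assert_eq!(normalize_call_result(SketchTy::X86Amx), SketchTy::I32Vector { lanes: 256 });
    assert_eq!(
        normalize_call_result(SketchTy::I32Vector { lanes: 4 }),
        SketchTy::I32Vector { lanes: 4 }
    );
}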
compiler/rustc_codegen_llvm/src/callee.rs (7 changes: 5 additions & 2 deletions)
@@ -4,7 +4,7 @@
//! and methods are represented as just a fn ptr and not a full
//! closure.
use rustc_codegen_ssa::common;
use rustc_codegen_ssa::{base, common};
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv};
use rustc_middle::ty::{self, Instance, TypeVisitableExt};
use tracing::debug;
@@ -36,6 +36,8 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
llfn
} else {
let instance_def_id = instance.def_id();
let is_llvm_intrinsic = base::is_llvm_intrinsic(tcx, instance_def_id);

let llfn = if tcx.sess.target.arch == "x86"
&& let Some(dllimport) = crate::common::get_dllimport(tcx, instance_def_id, sym)
{
@@ -53,6 +55,7 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
),
fn_abi,
Some(instance),
is_llvm_intrinsic,
);

// Fix for https://github.com/rust-lang/rust/issues/104453
@@ -69,7 +72,7 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
llvm::set_dllimport_storage_class(llfn);
llfn
} else {
cx.declare_fn(sym, fn_abi, Some(instance))
cx.declare_fn(sym, fn_abi, Some(instance), is_llvm_intrinsic)
};
debug!("get_fn: not casting pointer!");

compiler/rustc_codegen_llvm/src/consts.rs (7 changes: 6 additions & 1 deletion)
@@ -191,7 +191,12 @@ fn check_and_apply_linkage<'ll, 'tcx>(
let fn_sig = sig.with(*header);

let fn_abi = cx.fn_abi_of_fn_ptr(fn_sig, ty::List::empty());
cx.declare_fn(sym, &fn_abi, None)
cx.declare_fn(
sym,
&fn_abi,
None,
rustc_codegen_ssa::base::is_llvm_intrinsic(cx.tcx, def_id),
)
} else {
cx.declare_global(sym, cx.type_i8())
}
compiler/rustc_codegen_llvm/src/declare.rs (3 changes: 2 additions & 1 deletion)
@@ -147,6 +147,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
name: &str,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
instance: Option<Instance<'tcx>>,
is_llvm_intrinsic: bool,
) -> &'ll Value {
debug!("declare_rust_fn(name={:?}, fn_abi={:?})", name, fn_abi);

@@ -158,7 +159,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
fn_abi.llvm_cconv(self),
llvm::UnnamedAddr::Global,
llvm::Visibility::Default,
fn_abi.llvm_type(self),
fn_abi.llvm_type(self, name.as_ref(), is_llvm_intrinsic),
);
fn_abi.apply_attrs_llfn(self, llfn, instance);

compiler/rustc_codegen_llvm/src/intrinsic.rs (7 changes: 4 additions & 3 deletions)
@@ -1088,10 +1088,11 @@ fn gen_fn<'a, 'll, 'tcx>(
name: &str,
rust_fn_sig: ty::PolyFnSig<'tcx>,
codegen: &mut dyn FnMut(Builder<'a, 'll, 'tcx>),
is_llvm_intrinsic: bool,
) -> (&'ll Type, &'ll Value) {
let fn_abi = cx.fn_abi_of_fn_ptr(rust_fn_sig, ty::List::empty());
let llty = fn_abi.llvm_type(cx);
let llfn = cx.declare_fn(name, fn_abi, None);
let llty = fn_abi.llvm_type(cx, name.as_ref(), is_llvm_intrinsic);
let llfn = cx.declare_fn(name, fn_abi, None, is_llvm_intrinsic);
cx.set_frame_pointer_type(llfn);
cx.apply_target_cpu_attr(llfn);
// FIXME(eddyb) find a nicer way to do this.
@@ -1147,7 +1148,7 @@ fn get_rust_try_fn<'a, 'll, 'tcx>(
hir::Safety::Unsafe,
ExternAbi::Rust,
));
let rust_try = gen_fn(cx, "__rust_try", rust_fn_sig, codegen);
let rust_try = gen_fn(cx, "__rust_try", rust_fn_sig, codegen, false);
cx.rust_try_fn.set(Some(rust_try));
rust_try
}
compiler/rustc_codegen_llvm/src/llvm/ffi.rs (4 changes: 4 additions & 0 deletions)
@@ -1055,6 +1055,9 @@ unsafe extern "C" {
pub(crate) fn LLVMPointerTypeInContext(C: &Context, AddressSpace: c_uint) -> &Type;
pub(crate) fn LLVMVectorType(ElementType: &Type, ElementCount: c_uint) -> &Type;

// Special X86 Type for AMX
pub(crate) fn LLVMX86AMXTypeInContext(C: &Context) -> &Type;

pub(crate) fn LLVMGetElementType(Ty: &Type) -> &Type;
pub(crate) fn LLVMGetVectorSize(VectorTy: &Type) -> c_uint;

@@ -1177,6 +1180,7 @@

// Operations on functions
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
pub(crate) fn LLVMGetIntrinsicID(Fn: &Value) -> c_uint;

// Operations on parameters
pub(crate) fn LLVMIsAArgument(Val: &Value) -> Option<&Value>;
compiler/rustc_codegen_llvm/src/mono_item.rs (7 changes: 6 additions & 1 deletion)
@@ -53,7 +53,12 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> {
assert!(!instance.args.has_infer());

let fn_abi = self.fn_abi_of_instance(instance, ty::List::empty());
let lldecl = self.declare_fn(symbol_name, fn_abi, Some(instance));
let lldecl = self.declare_fn(
symbol_name,
fn_abi,
Some(instance),
rustc_codegen_ssa::base::is_llvm_intrinsic(self.tcx, instance.def_id()),
);
llvm::set_linkage(lldecl, base::linkage_to_llvm(linkage));
let attrs = self.tcx.codegen_fn_attrs(instance.def_id());
base::set_link_section(lldecl, attrs);
compiler/rustc_codegen_llvm/src/type_.rs (14 changes: 12 additions & 2 deletions)
@@ -154,6 +154,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
)
}
}

pub(crate) fn type_x86amx(&self) -> &'ll Type {
unsafe { llvm::LLVMX86AMXTypeInContext(self.llcx()) }
}
}

impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -284,8 +288,14 @@ impl<'ll, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
fn cast_backend_type(&self, ty: &CastTarget) -> &'ll Type {
ty.llvm_type(self)
}
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> &'ll Type {
fn_abi.llvm_type(self)
fn fn_decl_backend_type(
&self,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
fn_ptr: &'ll Value,
) -> &'ll Type {
let intrinsic_id = unsafe { llvm::LLVMGetIntrinsicID(fn_ptr) };
// When the function is not an intrinsic, `Intrinsic::getIntrinsicID` returns `Intrinsic::not_intrinsic`, which is always defined to be 0
fn_abi.llvm_type(self, llvm::get_value_name(fn_ptr), intrinsic_id != 0)
}
fn fn_ptr_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> &'ll Type {
fn_abi.ptr_to_llvm_type(self)
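`LLVMGetIntrinsicID` returning 0 corresponds to LLVM's `Intrinsic::not_intrinsic`, so the non-zero check marks declarations that LLVM itself recognizes as intrinsics before the function name is handed to `llvm_type`. A trivial sketch of that convention (the ID below is a placeholder, not a real intrinsic number):

// Intrinsic::not_intrinsic is 0, so any non-zero ID means LLVM recognizes the
// declaration as one of its intrinsics.
fn is_recognized_intrinsic(intrinsic_id: u32) -> bool {
    intrinsic_id != 0
}

fn main() {
    assert!(!is_recognized_intrinsic(0));
    assert!(is_recognized_intrinsic(42)); // placeholder ID, purely illustrative
}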
compiler/rustc_codegen_ssa/src/base.rs (16 changes: 8 additions & 8 deletions)
@@ -914,6 +914,14 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
ongoing_codegen
}

pub fn is_llvm_intrinsic(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
if let Some(name) = tcx.codegen_fn_attrs(def_id).link_name {
name.as_str().starts_with("llvm.")
} else {
false
}
}

/// Returns whether a call from the current crate to the [`Instance`] would produce a call
/// from `compiler_builtins` to a symbol the linker must resolve.
///
@@ -927,14 +935,6 @@ pub fn is_call_from_compiler_builtins_to_upstream_monomorphization<'tcx>(
tcx: TyCtxt<'tcx>,
instance: Instance<'tcx>,
) -> bool {
fn is_llvm_intrinsic(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
if let Some(name) = tcx.codegen_fn_attrs(def_id).link_name {
name.as_str().starts_with("llvm.")
} else {
false
}
}

let def_id = instance.def_id();
!def_id.is_local()
&& tcx.is_compiler_builtins(LOCAL_CRATE)
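`is_llvm_intrinsic` is unchanged apart from being hoisted out of `is_call_from_compiler_builtins_to_upstream_monomorphization` and made `pub` so the LLVM backend can reuse it. A standalone model of the check; the real helper reads `tcx.codegen_fn_attrs(def_id).link_name`, while this sketch just takes the resolved name:

// Anything whose #[link_name] starts with "llvm." counts as an LLVM intrinsic.
fn is_llvm_intrinsic_name(link_name: Option<&str>) -> bool {
    link_name.is_some_and(|name| name.starts_with("llvm."))
}

fn main() {
    assert!(is_llvm_intrinsic_name(Some("llvm.x86.tdpbssd.internal")));
    assert!(!is_llvm_intrinsic_name(Some("memcmp")));
    assert!(!is_llvm_intrinsic_name(None));
}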
compiler/rustc_codegen_ssa/src/mir/block.rs (4 changes: 2 additions & 2 deletions)
@@ -187,7 +187,7 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> {

// If there is a cleanup block and the function we're calling can unwind, then
// do an invoke, otherwise do a call.
let fn_ty = bx.fn_decl_backend_type(fn_abi);
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);

let fn_attrs = if bx.tcx().def_kind(fx.instance.def_id()).has_codegen_attrs() {
Some(bx.tcx().codegen_fn_attrs(fx.instance.def_id()))
@@ -1806,7 +1806,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
if is_call_from_compiler_builtins_to_upstream_monomorphization(bx.tcx(), instance) {
bx.abort();
} else {
let fn_ty = bx.fn_decl_backend_type(fn_abi);
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);

let llret = bx.call(fn_ty, None, Some(fn_abi), fn_ptr, &[], funclet.as_ref(), None);
bx.apply_attrs_to_cleanup_callsite(llret);
compiler/rustc_codegen_ssa/src/mir/rvalue.rs (2 changes: 1 addition & 1 deletion)
@@ -779,7 +779,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
};
let fn_ptr = bx.get_fn_addr(instance);
let fn_abi = bx.fn_abi_of_instance(instance, ty::List::empty());
let fn_ty = bx.fn_decl_backend_type(fn_abi);
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);
let fn_attrs = if bx.tcx().def_kind(instance.def_id()).has_codegen_attrs() {
Some(bx.tcx().codegen_fn_attrs(instance.def_id()))
} else {
compiler/rustc_codegen_ssa/src/size_of_val.rs (2 changes: 1 addition & 1 deletion)
@@ -67,7 +67,7 @@ pub fn size_and_align_of_dst<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
// Generate the call. Cannot use `do_call` since we don't have a MIR terminator so we
// can't create a `TerminationCodegenHelper`. (But we are in good company, this code is
// duplicated plenty of times.)
let fn_ty = bx.fn_decl_backend_type(fn_abi);
let fn_ty = bx.fn_decl_backend_type(fn_abi, llfn);

bx.call(
fn_ty,
compiler/rustc_codegen_ssa/src/traits/type_.rs (6 changes: 5 additions & 1 deletion)
@@ -96,7 +96,11 @@ pub trait LayoutTypeCodegenMethods<'tcx>: BackendTypes {
/// such as when it's stack-allocated or when it's being loaded or stored.
fn backend_type(&self, layout: TyAndLayout<'tcx>) -> Self::Type;
fn cast_backend_type(&self, ty: &CastTarget) -> Self::Type;
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Self::Type;
fn fn_decl_backend_type(
&self,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
fn_ptr: Self::Value,
) -> Self::Type;
fn fn_ptr_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Self::Type;
fn reg_backend_type(&self, ty: &Reg) -> Self::Type;
/// The backend type used for a rust type when it's in an SSA register.
compiler/rustc_target/src/target_features.rs (2 changes: 1 addition & 1 deletion)
@@ -786,7 +786,7 @@ pub fn all_rust_features() -> impl Iterator<Item = (&'static str, Stability)> {
// certain size to have their "proper" ABI on each architecture.
// Note that they must be kept sorted by vector size.
const X86_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] =
&[(128, "sse"), (256, "avx"), (512, "avx512f")]; // FIXME: might need changes for AVX10.
&[(128, "sse"), (256, "avx"), (512, "avx512f"), (8192, "amx-tile")];
const AARCH64_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] = &[(128, "neon")];

// We might want to add "helium" too.
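The new `(8192, "amx-tile")` entry follows the table's convention of keying a feature to a vector width in bits: one AMX tile register is 1 KiB, i.e. 256 lanes of 32 bits = 8192 bits, matching the `i32x256` representation used on the Rust side. A trivial check of that arithmetic:

// 256 lanes x 32 bits = 8192 bits = 1024 bytes, the size of one AMX tile register.
const AMX_TILE_LANES: u64 = 256;
const LANE_BITS: u64 = 32;

fn main() {
    assert_eq!(AMX_TILE_LANES * LANE_BITS, 8192);
    assert_eq!(AMX_TILE_LANES * LANE_BITS / 8, 1024);
}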