Skip to content

Commit 269d384

Browse files
authored
Refactor fnc tt (#133)
* fix wrapper * add tt writes into module * don't pass tt anymore through Enzyme API
1 parent eba256f commit 269d384

File tree

4 files changed

+88
-19
lines changed

4 files changed

+88
-19
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ use llvm::{
5555
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock,
5656
};
5757
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode};
58+
use rustc_ast::expand::typetree::FncTree;
5859
use rustc_codegen_ssa::back::link::ensure_removed;
5960
use rustc_codegen_ssa::back::write::{
6061
BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig,
@@ -1091,6 +1092,24 @@ pub(crate) unsafe fn differentiate(
10911092
llvm::set_loose_types(true);
10921093
}
10931094

1095+
// Before dumping the module, we want all the tt to become part of the module.
1096+
for item in &diff_items {
1097+
let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
1098+
let llvm_data_layout =
1099+
std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes())
1100+
.expect("got a non-UTF8 data-layout from LLVM");
1101+
//let input_tts: Vec<TypeTree> =
1102+
// item.inputs.iter().map(|x| to_enzyme_typetree(x.clone(), llvm_data_layout, llcx)).collect();
1103+
//let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);
1104+
let tt: FncTree = FncTree {
1105+
args: item.inputs.clone(),
1106+
ret: item.output.clone(),
1107+
};
1108+
let name = CString::new(item.source.clone()).unwrap();
1109+
let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap();
1110+
crate::builder::add_tt2(llmod, llcx, fn_def, tt);
1111+
}
1112+
10941113
if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() {
10951114
unsafe {
10961115
LLVMDumpModule(llmod);

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,49 @@ macro_rules! builder_methods_for_value_instructions {
136136
})+
137137
}
138138
}
139+
pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) {
140+
let inputs = tt.args;
141+
let ret_tt: TypeTree = tt.ret;
142+
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
143+
let llvm_data_layout =
144+
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
145+
.expect("got a non-UTF8 data-layout from LLVM");
146+
let attr_name = "enzyme_type";
147+
let c_attr_name = std::ffi::CString::new(attr_name).unwrap();
148+
for (i, &ref input) in inputs.iter().enumerate() {
149+
let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
150+
let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) };
151+
let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) };
152+
unsafe {
153+
let attr = llvm::LLVMCreateStringAttribute(
154+
llcx,
155+
c_attr_name.as_ptr(),
156+
c_attr_name.as_bytes().len() as c_uint,
157+
c_str.as_ptr(),
158+
c_str.to_bytes().len() as c_uint,
159+
);
160+
llvm::LLVMRustAddFncParamAttr(fn_def, i as u32, attr);
161+
}
162+
unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) };
163+
}
164+
let ret_attr = unsafe {
165+
let c_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
166+
let c_str = llvm::EnzymeTypeTreeToString(c_tt.inner);
167+
let c_str = std::ffi::CStr::from_ptr(c_str);
168+
let attr = llvm::LLVMCreateStringAttribute(
169+
llcx,
170+
c_attr_name.as_ptr(),
171+
c_attr_name.as_bytes().len() as c_uint,
172+
c_str.as_ptr(),
173+
c_str.to_bytes().len() as c_uint,
174+
);
175+
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
176+
attr
177+
};
178+
unsafe {
179+
llvm::LLVMRustAddRetFncAttr(fn_def, ret_attr);
180+
}
181+
}
139182

140183
fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) {
141184
let inputs = tt.args;

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -848,8 +848,8 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
848848
fnc: &Value,
849849
input_diffactivity: Vec<DiffActivity>,
850850
ret_diffactivity: DiffActivity,
851-
input_tts: Vec<TypeTree>,
852-
output_tt: TypeTree,
851+
_input_tts: Vec<TypeTree>,
852+
_output_tt: TypeTree,
853853
void_ret: bool,
854854
) -> (&Value, Vec<usize>) {
855855
let ret_activity = cdiffe_from(ret_diffactivity);
@@ -878,13 +878,12 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
878878
};
879879
trace!("ret_primary_ret: {}", &ret_primary_ret);
880880

881-
let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
881+
//let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
882882
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];
883883

884884
// We don't support volatile / extern / (global?) values.
885885
// Just because I didn't have time to test them, and it seems less urgent.
886-
let args_uncacheable = vec![0; input_tts.len()];
887-
assert!(args_uncacheable.len() == input_activity.len());
886+
let args_uncacheable = vec![0; input_activity.len()];
888887
let num_fnc_args = LLVMCountParams(fnc);
889888
trace!("num_fnc_args: {}", num_fnc_args);
890889
trace!("input_activity.len(): {}", input_activity.len());
@@ -894,9 +893,16 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
894893

895894
let mut known_values = vec![kv_tmp; input_activity.len()];
896895

896+
let tree_tmp = TypeTree::new();
897+
let mut args_tree = vec![tree_tmp.inner; input_activity.len()];
898+
899+
//let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()];
900+
//let ret_tt = std::ptr::null_mut();
901+
//let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
902+
let ret_tt = TypeTree::new();
897903
let dummy_type = CFnTypeInfo {
898904
Arguments: args_tree.as_mut_ptr(),
899-
Return: output_tt.inner.clone(),
905+
Return: ret_tt.inner,
900906
KnownValues: known_values.as_mut_ptr(),
901907
};
902908

@@ -935,7 +941,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
935941
rust_input_activity: Vec<DiffActivity>,
936942
ret_activity: DiffActivity,
937943
input_tts: Vec<TypeTree>,
938-
output_tt: TypeTree,
944+
_output_tt: TypeTree,
939945
) -> (&Value, Vec<usize>) {
940946
let (primary_ret, ret_activity) = match ret_activity {
941947
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
@@ -961,16 +967,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
961967
input_activity.push(cdiffe_from(x));
962968
}
963969

964-
let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
970+
//let args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
965971

966972
// We don't support volatile / extern / (global?) values.
967973
// Just because I didn't have time to test them, and it seems less urgent.
968-
let args_uncacheable = vec![0; input_tts.len()];
969-
if args_uncacheable.len() != input_activity.len() {
970-
dbg!("args_uncacheable.len(): {}", args_uncacheable.len());
971-
dbg!("input_activity.len(): {}", input_activity.len());
972-
}
973-
assert!(args_uncacheable.len() == input_activity.len());
974+
let args_uncacheable = vec![0; input_activity.len()];
974975
let num_fnc_args = LLVMCountParams(fnc);
975976
println!("num_fnc_args: {}", num_fnc_args);
976977
println!("input_activity.len(): {}", input_activity.len());
@@ -979,9 +980,15 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
979980

980981
let mut known_values = vec![kv_tmp; input_tts.len()];
981982

983+
let tree_tmp = TypeTree::new();
984+
let mut args_tree = vec![tree_tmp.inner; input_tts.len()];
985+
//let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
986+
let ret_tt = TypeTree::new();
987+
//let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()];
988+
//let ret_tt = std::ptr::null_mut();
982989
let dummy_type = CFnTypeInfo {
983990
Arguments: args_tree.as_mut_ptr(),
984-
Return: output_tt.inner.clone(),
991+
Return: ret_tt.inner,
985992
KnownValues: known_values.as_mut_ptr(),
986993
};
987994

@@ -1023,12 +1030,12 @@ extern "C" {
10231030
//pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value;
10241031
// Enzyme
10251032
pub fn LLVMRustAddFncParamAttr<'a>(
1026-
Instr: &'a Value,
1033+
F: &'a Value,
10271034
index: c_uint,
10281035
Attr: &'a Attribute
10291036
);
10301037

1031-
pub fn LLVMRustAddRetAttr(V: &Value, attr: AttributeKind);
1038+
pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute);
10321039
pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind);
10331040
pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool;
10341041
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -865,9 +865,9 @@ extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i,
865865
}
866866

867867
extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F,
868-
LLVMRustAttribute RustAttr) {
868+
LLVMAttributeRef RustAttr) {
869869
if (auto *Fn = dyn_cast<Function>(unwrap<Value>(F))) {
870-
Fn->addRetAttr(fromRust(RustAttr));
870+
Fn->addRetAttr(unwrap(RustAttr));
871871
}
872872
}
873873

0 commit comments

Comments
 (0)