@@ -848,8 +848,8 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
848848 fnc : & Value ,
849849 input_diffactivity : Vec < DiffActivity > ,
850850 ret_diffactivity : DiffActivity ,
851- input_tts : Vec < TypeTree > ,
852- output_tt : TypeTree ,
851+ _input_tts : Vec < TypeTree > ,
852+ _output_tt : TypeTree ,
853853 void_ret : bool ,
854854) -> ( & Value , Vec < usize > ) {
855855 let ret_activity = cdiffe_from ( ret_diffactivity) ;
@@ -878,13 +878,12 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
878878 } ;
879879 trace ! ( "ret_primary_ret: {}" , & ret_primary_ret) ;
880880
881- let mut args_tree = input_tts. iter ( ) . map ( |x| x. inner ) . collect :: < Vec < _ > > ( ) ;
881+ // let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
882882 //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];
883883
884884 // We don't support volatile / extern / (global?) values.
885885 // Just because I didn't had time to test them, and it seems less urgent.
886- let args_uncacheable = vec ! [ 0 ; input_tts. len( ) ] ;
887- assert ! ( args_uncacheable. len( ) == input_activity. len( ) ) ;
886+ let args_uncacheable = vec ! [ 0 ; input_activity. len( ) ] ;
888887 let num_fnc_args = LLVMCountParams ( fnc) ;
889888 trace ! ( "num_fnc_args: {}" , num_fnc_args) ;
890889 trace ! ( "input_activity.len(): {}" , input_activity. len( ) ) ;
@@ -894,9 +893,16 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
894893
895894 let mut known_values = vec ! [ kv_tmp; input_activity. len( ) ] ;
896895
896+ let tree_tmp = TypeTree :: new ( ) ;
897+ let mut args_tree = vec ! [ tree_tmp. inner; input_activity. len( ) ] ;
898+
899+ //let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()];
900+ //let ret_tt = std::ptr::null_mut();
901+ //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
902+ let ret_tt = TypeTree :: new ( ) ;
897903 let dummy_type = CFnTypeInfo {
898904 Arguments : args_tree. as_mut_ptr ( ) ,
899- Return : output_tt . inner . clone ( ) ,
905+ Return : ret_tt . inner ,
900906 KnownValues : known_values. as_mut_ptr ( ) ,
901907 } ;
902908
@@ -935,7 +941,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
935941 rust_input_activity : Vec < DiffActivity > ,
936942 ret_activity : DiffActivity ,
937943 input_tts : Vec < TypeTree > ,
938- output_tt : TypeTree ,
944+ _output_tt : TypeTree ,
939945) -> ( & Value , Vec < usize > ) {
940946 let ( primary_ret, ret_activity) = match ret_activity {
941947 DiffActivity :: Const => ( true , CDIFFE_TYPE :: DFT_CONSTANT ) ,
@@ -961,16 +967,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
961967 input_activity. push ( cdiffe_from ( x) ) ;
962968 }
963969
964- let mut args_tree = input_tts. iter ( ) . map ( |x| x. inner ) . collect :: < Vec < _ > > ( ) ;
970+ // let args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
965971
966972 // We don't support volatile / extern / (global?) values.
967973 // Just because I didn't had time to test them, and it seems less urgent.
968- let args_uncacheable = vec ! [ 0 ; input_tts. len( ) ] ;
969- if args_uncacheable. len ( ) != input_activity. len ( ) {
970- dbg ! ( "args_uncacheable.len(): {}" , args_uncacheable. len( ) ) ;
971- dbg ! ( "input_activity.len(): {}" , input_activity. len( ) ) ;
972- }
973- assert ! ( args_uncacheable. len( ) == input_activity. len( ) ) ;
974+ let args_uncacheable = vec ! [ 0 ; input_activity. len( ) ] ;
974975 let num_fnc_args = LLVMCountParams ( fnc) ;
975976 println ! ( "num_fnc_args: {}" , num_fnc_args) ;
976977 println ! ( "input_activity.len(): {}" , input_activity. len( ) ) ;
@@ -979,9 +980,15 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
979980
980981 let mut known_values = vec ! [ kv_tmp; input_tts. len( ) ] ;
981982
983+ let tree_tmp = TypeTree :: new ( ) ;
984+ let mut args_tree = vec ! [ tree_tmp. inner; input_tts. len( ) ] ;
985+ //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
986+ let ret_tt = TypeTree :: new ( ) ;
987+ //let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()];
988+ //let ret_tt = std::ptr::null_mut();
982989 let dummy_type = CFnTypeInfo {
983990 Arguments : args_tree. as_mut_ptr ( ) ,
984- Return : output_tt . inner . clone ( ) ,
991+ Return : ret_tt . inner ,
985992 KnownValues : known_values. as_mut_ptr ( ) ,
986993 } ;
987994
@@ -1023,12 +1030,12 @@ extern "C" {
10231030 //pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value;
10241031 // Enzyme
10251032 pub fn LLVMRustAddFncParamAttr < ' a > (
1026- Instr : & ' a Value ,
1033+ F : & ' a Value ,
10271034 index : c_uint ,
10281035 Attr : & ' a Attribute
10291036 ) ;
10301037
1031- pub fn LLVMRustAddRetAttr ( V : & Value , attr : AttributeKind ) ;
1038+ pub fn LLVMRustAddRetFncAttr ( F : & Value , attr : & Attribute ) ;
10321039 pub fn LLVMRustRemoveFncAttr ( V : & Value , attr : AttributeKind ) ;
10331040 pub fn LLVMRustHasDbgMetadata ( I : & Value ) -> bool ;
10341041 pub fn LLVMRustHasMetadata ( I : & Value , KindID : c_uint ) -> bool ;
0 commit comments