@@ -848,8 +848,8 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
848
848
fnc : & Value ,
849
849
input_diffactivity : Vec < DiffActivity > ,
850
850
ret_diffactivity : DiffActivity ,
851
- input_tts : Vec < TypeTree > ,
852
- output_tt : TypeTree ,
851
+ _input_tts : Vec < TypeTree > ,
852
+ _output_tt : TypeTree ,
853
853
void_ret : bool ,
854
854
) -> ( & Value , Vec < usize > ) {
855
855
let ret_activity = cdiffe_from ( ret_diffactivity) ;
@@ -878,13 +878,12 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
878
878
} ;
879
879
trace ! ( "ret_primary_ret: {}" , & ret_primary_ret) ;
880
880
881
- let mut args_tree = input_tts. iter ( ) . map ( |x| x. inner ) . collect :: < Vec < _ > > ( ) ;
881
+ // let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
882
882
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];
883
883
884
884
// We don't support volatile / extern / (global?) values.
885
885
// Just because I didn't had time to test them, and it seems less urgent.
886
- let args_uncacheable = vec ! [ 0 ; input_tts. len( ) ] ;
887
- assert ! ( args_uncacheable. len( ) == input_activity. len( ) ) ;
886
+ let args_uncacheable = vec ! [ 0 ; input_activity. len( ) ] ;
888
887
let num_fnc_args = LLVMCountParams ( fnc) ;
889
888
trace ! ( "num_fnc_args: {}" , num_fnc_args) ;
890
889
trace ! ( "input_activity.len(): {}" , input_activity. len( ) ) ;
@@ -894,9 +893,16 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
894
893
895
894
let mut known_values = vec ! [ kv_tmp; input_activity. len( ) ] ;
896
895
896
+ let tree_tmp = TypeTree :: new ( ) ;
897
+ let mut args_tree = vec ! [ tree_tmp. inner; input_activity. len( ) ] ;
898
+
899
+ //let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()];
900
+ //let ret_tt = std::ptr::null_mut();
901
+ //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
902
+ let ret_tt = TypeTree :: new ( ) ;
897
903
let dummy_type = CFnTypeInfo {
898
904
Arguments : args_tree. as_mut_ptr ( ) ,
899
- Return : output_tt . inner . clone ( ) ,
905
+ Return : ret_tt . inner ,
900
906
KnownValues : known_values. as_mut_ptr ( ) ,
901
907
} ;
902
908
@@ -935,7 +941,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
935
941
rust_input_activity : Vec < DiffActivity > ,
936
942
ret_activity : DiffActivity ,
937
943
input_tts : Vec < TypeTree > ,
938
- output_tt : TypeTree ,
944
+ _output_tt : TypeTree ,
939
945
) -> ( & Value , Vec < usize > ) {
940
946
let ( primary_ret, ret_activity) = match ret_activity {
941
947
DiffActivity :: Const => ( true , CDIFFE_TYPE :: DFT_CONSTANT ) ,
@@ -961,16 +967,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
961
967
input_activity. push ( cdiffe_from ( x) ) ;
962
968
}
963
969
964
- let mut args_tree = input_tts. iter ( ) . map ( |x| x. inner ) . collect :: < Vec < _ > > ( ) ;
970
+ // let args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
965
971
966
972
// We don't support volatile / extern / (global?) values.
967
973
// Just because I didn't had time to test them, and it seems less urgent.
968
- let args_uncacheable = vec ! [ 0 ; input_tts. len( ) ] ;
969
- if args_uncacheable. len ( ) != input_activity. len ( ) {
970
- dbg ! ( "args_uncacheable.len(): {}" , args_uncacheable. len( ) ) ;
971
- dbg ! ( "input_activity.len(): {}" , input_activity. len( ) ) ;
972
- }
973
- assert ! ( args_uncacheable. len( ) == input_activity. len( ) ) ;
974
+ let args_uncacheable = vec ! [ 0 ; input_activity. len( ) ] ;
974
975
let num_fnc_args = LLVMCountParams ( fnc) ;
975
976
println ! ( "num_fnc_args: {}" , num_fnc_args) ;
976
977
println ! ( "input_activity.len(): {}" , input_activity. len( ) ) ;
@@ -979,9 +980,15 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
979
980
980
981
let mut known_values = vec ! [ kv_tmp; input_tts. len( ) ] ;
981
982
983
+ let tree_tmp = TypeTree :: new ( ) ;
984
+ let mut args_tree = vec ! [ tree_tmp. inner; input_tts. len( ) ] ;
985
+ //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
986
+ let ret_tt = TypeTree :: new ( ) ;
987
+ //let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()];
988
+ //let ret_tt = std::ptr::null_mut();
982
989
let dummy_type = CFnTypeInfo {
983
990
Arguments : args_tree. as_mut_ptr ( ) ,
984
- Return : output_tt . inner . clone ( ) ,
991
+ Return : ret_tt . inner ,
985
992
KnownValues : known_values. as_mut_ptr ( ) ,
986
993
} ;
987
994
@@ -1023,12 +1030,12 @@ extern "C" {
1023
1030
//pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value;
1024
1031
// Enzyme
1025
1032
pub fn LLVMRustAddFncParamAttr < ' a > (
1026
- Instr : & ' a Value ,
1033
+ F : & ' a Value ,
1027
1034
index : c_uint ,
1028
1035
Attr : & ' a Attribute
1029
1036
) ;
1030
1037
1031
- pub fn LLVMRustAddRetAttr ( V : & Value , attr : AttributeKind ) ;
1038
+ pub fn LLVMRustAddRetFncAttr ( F : & Value , attr : & Attribute ) ;
1032
1039
pub fn LLVMRustRemoveFncAttr ( V : & Value , attr : AttributeKind ) ;
1033
1040
pub fn LLVMRustHasDbgMetadata ( I : & Value ) -> bool ;
1034
1041
pub fn LLVMRustHasMetadata ( I : & Value , KindID : c_uint ) -> bool ;
0 commit comments