@@ -637,7 +637,8 @@ def get_rtlib_dir():
637
637
638
638
def get_type_info (arg ):
639
639
# return_type -> (`type_format`, `variable type`, `array struct name`)
640
- # See: https://docs.python.org/3/c-api/arg.html for more info on type_format
640
+ # See: https://docs.python.org/3/c-api/arg.html for more info on `type_format`
641
+ # `array struct name`: used by the C backend
641
642
if arg == f64 :
642
643
return ('d' , "double" , 'r64' )
643
644
elif arg == f32 :
@@ -652,7 +653,10 @@ def get_type_info(arg):
652
653
t = get_type_info (arg ._type )
653
654
if t [2 ] == '' :
654
655
raise NotImplementedError ("Type %r not implemented" % arg )
655
- return ('O' , ["PyArrayObject *" , "struct " + t [2 ]+ " *" , t [1 ]+ " *" ], '' )
656
+ n = ''
657
+ if not isinstance (arg ._dims , slice ):
658
+ n = arg ._dims ._name
659
+ return ('O' , ["PyArrayObject *" , "struct " + t [2 ]+ " *" , t [1 ]+ " *" , n ], '' )
656
660
else :
657
661
raise NotImplementedError ("Type %r not implemented" % arg )
658
662
@@ -662,6 +666,18 @@ def get_data_type(t):
662
666
else :
663
667
return t + " "
664
668
669
+ def get_typenum (t ):
670
+ if t == "int" :
671
+ return "NPY_INT"
672
+ elif t == "long int" :
673
+ return "NPY_LONG"
674
+ elif t == "float" :
675
+ return "NPY_FLOAT"
676
+ elif t == "double" :
677
+ return "NPY_DOUBLE"
678
+ else :
679
+ raise NotImplementedError ("Type %s not implemented" % t )
680
+
665
681
self .fn_name = function .__name__
666
682
# Get the source code of the function
667
683
source_code = getsource (function )
@@ -682,46 +698,55 @@ def get_data_type(t):
682
698
self .arg_type_formats = ""
683
699
self .return_type = ""
684
700
self .return_type_format = ""
701
+ self .array_as_return_type = ()
685
702
self .arg_types = {}
686
- counter = 1
687
703
for t in types .keys ():
688
704
if t == "return" :
689
705
type = get_type_info (types [t ])
690
- self .return_type_format = type [0 ]
691
- self .return_type = type [1 ]
706
+ if type [0 ] == 'O' :
707
+ self .array_as_return_type = type
708
+ continue
709
+ else :
710
+ self .return_type_format = type [0 ]
711
+ self .return_type = type [1 ]
692
712
else :
693
713
type = get_type_info (types [t ])
694
714
self .arg_type_formats += type [0 ]
695
- self .arg_types [counter ] = type [1 ]
696
- counter += 1
715
+ self .arg_types [t ] = type [1 ]
697
716
# ----------------------------------------------------------------------
698
- # `arg_0`: used as the return variables
699
- # arguments are declared as `arg_1`, `arg_2`, ...
700
- variables_decl = ""
717
+ # `_<fn_name>_return_value`: used as the return variables
718
+ variables_decl = "// Declare return variables and arguments\n "
701
719
if self .return_type != "" :
702
- variables_decl = "// Declare return variables and arguments\n "
703
- variables_decl += " " + get_data_type (self .return_type ) + "arg_" \
704
- + str (0 ) + ";\n "
720
+ variables_decl += " " + get_data_type (self .return_type ) \
721
+ + "_" + self .fn_name + "_return_value;\n "
722
+ elif self .array_as_return_type :
723
+ variables_decl += " " + self .array_as_return_type [1 ][1 ] + "_" \
724
+ + self .fn_name + "_return_value = malloc(sizeof(" \
725
+ + self .array_as_return_type [1 ][1 ][:- 2 ] + "));\n "
726
+ else :
727
+ variables_decl = ""
705
728
# ----------------------------------------------------------------------
706
729
# `PyArray_AsCArray` is used to convert NumPy Arrays to C Arrays
707
- # `fill_array_details` contains arrays operations to be
730
+ # `fill_array_details` contains array operations to be
708
731
# performed on the arguments
709
732
# `parse_args` are used to capture the args from CPython
710
733
# `pass_args` are the args that are passed to the shared library function
711
734
fill_array_details = ""
712
735
parse_args = ""
713
736
pass_args = ""
714
737
numpy_init = ""
738
+ prefix_comma = False
715
739
for i , t in self .arg_types .items ():
716
- if i > 1 :
740
+ if prefix_comma :
717
741
parse_args += ", "
718
742
pass_args += ", "
743
+ prefix_comma = True
719
744
if isinstance (t , list ):
720
745
if numpy_init == "" :
721
746
numpy_init = "// Initialize NumPy\n import_array();\n \n "
722
747
fill_array_details += f"""\n
723
- // fill array details for args[ { i - 1 } ]
724
- if (PyArray_NDIM(arg_ { i } ) != 1) {{
748
+ // fill array details for { i }
749
+ if (PyArray_NDIM({ i } ) != 1) {{
725
750
PyErr_SetString(PyExc_TypeError,
726
751
"Only 1 dimension is implemented for now.");
727
752
return NULL;
@@ -731,9 +756,9 @@ def get_data_type(t):
731
756
{{
732
757
{ t [2 ]} array;
733
758
// Create C arrays from numpy objects:
734
- PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE(arg_ { i } ));
759
+ PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE({ i } ));
735
760
npy_intp dims[1];
736
- if (PyArray_AsCArray((PyObject **)&arg_ { i } , (void *)&array, dims, 1, descr) < 0) {{
761
+ if (PyArray_AsCArray((PyObject **)&{ i } , (void *)&array, dims, 1, descr) < 0) {{
737
762
PyErr_SetString(PyExc_TypeError, "error converting to c array");
738
763
return NULL;
739
764
}}
@@ -744,11 +769,11 @@ def get_data_type(t):
744
769
s_array_{ i } ->dims[0].length = dims[0];
745
770
s_array_{ i } ->is_allocated = false;
746
771
}}"""
747
- pass_args += "s_array_" + str ( i )
772
+ pass_args += "s_array_" + i
748
773
else :
749
- pass_args += "arg_" + str ( i )
750
- variables_decl += " " + get_data_type (t ) + "arg_" + str ( i ) + ";\n "
751
- parse_args += "&arg_ " + str ( i )
774
+ pass_args += i
775
+ variables_decl += " " + get_data_type (t ) + i + ";\n "
776
+ parse_args += "&" + i
752
777
753
778
if parse_args != "" :
754
779
parse_args = f"""\n // Parse the arguments from Python
@@ -761,12 +786,38 @@ def get_data_type(t):
761
786
fill_return_details = ""
762
787
if self .return_type != "" :
763
788
fill_return_details = f"""\n \n // Call the C function
764
- arg_0 = { self .fn_name } ({ pass_args } );
789
+ _ { self . fn_name } _return_value = { self .fn_name } ({ pass_args } );
765
790
766
791
// Build and return the result as a Python object
767
- return Py_BuildValue("{ self .return_type_format } ", arg_0 );"""
792
+ return Py_BuildValue("{ self .return_type_format } ", _ { self . fn_name } _return_value );"""
768
793
else :
769
- fill_return_details = f"""{ self .fn_name } ({ pass_args } );
794
+ if self .array_as_return_type :
795
+ fill_return_details = f"""\n
796
+ _{ self .fn_name } _return_value->data = malloc({ self .array_as_return_type [1 ][3 ]
797
+ } * sizeof({ self .array_as_return_type [1 ][2 ][:- 2 ]} ));
798
+ _{ self .fn_name } _return_value->n_dims = 1;
799
+ _{ self .fn_name } _return_value->dims[0].lower_bound = 0;
800
+ _{ self .fn_name } _return_value->dims[0].length = {
801
+ self .array_as_return_type [1 ][3 ]} ;
802
+ _{ self .fn_name } _return_value->is_allocated = false;
803
+
804
+ // Call the C function
805
+ { self .fn_name } ({ pass_args } , &_{ self .fn_name } _return_value[0]);
806
+
807
+ // Build and return the result as a Python object
808
+ {{
809
+ npy_intp dims[] = {{{ self .array_as_return_type [1 ][3 ]} }};
810
+ PyObject* numpy_array = PyArray_SimpleNewFromData(1, dims, {
811
+ get_typenum (self .array_as_return_type [1 ][2 ][:- 2 ])} ,
812
+ _{ self .fn_name } _return_value->data);
813
+ if (numpy_array == NULL) {{
814
+ PyErr_SetString(PyExc_TypeError, "error creating an array");
815
+ return NULL;
816
+ }}
817
+ return numpy_array;
818
+ }}"""
819
+ else :
820
+ fill_return_details = f"""{ self .fn_name } ({ pass_args } );
770
821
Py_RETURN_NONE;"""
771
822
772
823
# ----------------------------------------------------------------------
0 commit comments