@@ -637,7 +637,8 @@ def get_rtlib_dir():
637637
638638 def get_type_info (arg ):
639639 # return_type -> (`type_format`, `variable type`, `array struct name`)
640- # See: https://docs.python.org/3/c-api/arg.html for more info on type_format
640+ # See: https://docs.python.org/3/c-api/arg.html for more info on `type_format`
641+ # `array struct name`: used by the C backend
641642 if arg == f64 :
642643 return ('d' , "double" , 'r64' )
643644 elif arg == f32 :
@@ -652,7 +653,10 @@ def get_type_info(arg):
652653 t = get_type_info (arg ._type )
653654 if t [2 ] == '' :
654655 raise NotImplementedError ("Type %r not implemented" % arg )
655- return ('O' , ["PyArrayObject *" , "struct " + t [2 ]+ " *" , t [1 ]+ " *" ], '' )
656+ n = ''
657+ if not isinstance (arg ._dims , slice ):
658+ n = arg ._dims ._name
659+ return ('O' , ["PyArrayObject *" , "struct " + t [2 ]+ " *" , t [1 ]+ " *" , n ], '' )
656660 else :
657661 raise NotImplementedError ("Type %r not implemented" % arg )
658662
@@ -662,6 +666,18 @@ def get_data_type(t):
662666 else :
663667 return t + " "
664668
669+ def get_typenum (t ):
670+ if t == "int" :
671+ return "NPY_INT"
672+ elif t == "long int" :
673+ return "NPY_LONG"
674+ elif t == "float" :
675+ return "NPY_FLOAT"
676+ elif t == "double" :
677+ return "NPY_DOUBLE"
678+ else :
679+ raise NotImplementedError ("Type %s not implemented" % t )
680+
665681 self .fn_name = function .__name__
666682 # Get the source code of the function
667683 source_code = getsource (function )
@@ -682,46 +698,55 @@ def get_data_type(t):
682698 self .arg_type_formats = ""
683699 self .return_type = ""
684700 self .return_type_format = ""
701+ self .array_as_return_type = ()
685702 self .arg_types = {}
686- counter = 1
687703 for t in types .keys ():
688704 if t == "return" :
689705 type = get_type_info (types [t ])
690- self .return_type_format = type [0 ]
691- self .return_type = type [1 ]
706+ if type [0 ] == 'O' :
707+ self .array_as_return_type = type
708+ continue
709+ else :
710+ self .return_type_format = type [0 ]
711+ self .return_type = type [1 ]
692712 else :
693713 type = get_type_info (types [t ])
694714 self .arg_type_formats += type [0 ]
695- self .arg_types [counter ] = type [1 ]
696- counter += 1
715+ self .arg_types [t ] = type [1 ]
697716 # ----------------------------------------------------------------------
698- # `arg_0`: used as the return variables
699- # arguments are declared as `arg_1`, `arg_2`, ...
700- variables_decl = ""
717+ # `_<fn_name>_return_value`: used as the return variables
718+ variables_decl = "// Declare return variables and arguments\n "
701719 if self .return_type != "" :
702- variables_decl = "// Declare return variables and arguments\n "
703- variables_decl += " " + get_data_type (self .return_type ) + "arg_" \
704- + str (0 ) + ";\n "
720+ variables_decl += " " + get_data_type (self .return_type ) \
721+ + "_" + self .fn_name + "_return_value;\n "
722+ elif self .array_as_return_type :
723+ variables_decl += " " + self .array_as_return_type [1 ][1 ] + "_" \
724+ + self .fn_name + "_return_value = malloc(sizeof(" \
725+ + self .array_as_return_type [1 ][1 ][:- 2 ] + "));\n "
726+ else :
727+ variables_decl = ""
705728 # ----------------------------------------------------------------------
706729 # `PyArray_AsCArray` is used to convert NumPy Arrays to C Arrays
707- # `fill_array_details` contains arrays operations to be
730+ # `fill_array_details` contains array operations to be
708731 # performed on the arguments
709732 # `parse_args` are used to capture the args from CPython
710733 # `pass_args` are the args that are passed to the shared library function
711734 fill_array_details = ""
712735 parse_args = ""
713736 pass_args = ""
714737 numpy_init = ""
738+ prefix_comma = False
715739 for i , t in self .arg_types .items ():
716- if i > 1 :
740+ if prefix_comma :
717741 parse_args += ", "
718742 pass_args += ", "
743+ prefix_comma = True
719744 if isinstance (t , list ):
720745 if numpy_init == "" :
721746 numpy_init = "// Initialize NumPy\n import_array();\n \n "
722747 fill_array_details += f"""\n
723- // fill array details for args[ { i - 1 } ]
724- if (PyArray_NDIM(arg_ { i } ) != 1) {{
748+ // fill array details for { i }
749+ if (PyArray_NDIM({ i } ) != 1) {{
725750 PyErr_SetString(PyExc_TypeError,
726751 "Only 1 dimension is implemented for now.");
727752 return NULL;
@@ -731,9 +756,9 @@ def get_data_type(t):
731756 {{
732757 { t [2 ]} array;
733758 // Create C arrays from numpy objects:
734- PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE(arg_ { i } ));
759+ PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE({ i } ));
735760 npy_intp dims[1];
736- if (PyArray_AsCArray((PyObject **)&arg_ { i } , (void *)&array, dims, 1, descr) < 0) {{
761+ if (PyArray_AsCArray((PyObject **)&{ i } , (void *)&array, dims, 1, descr) < 0) {{
737762 PyErr_SetString(PyExc_TypeError, "error converting to c array");
738763 return NULL;
739764 }}
@@ -744,11 +769,11 @@ def get_data_type(t):
744769 s_array_{ i } ->dims[0].length = dims[0];
745770 s_array_{ i } ->is_allocated = false;
746771 }}"""
747- pass_args += "s_array_" + str ( i )
772+ pass_args += "s_array_" + i
748773 else :
749- pass_args += "arg_" + str ( i )
750- variables_decl += " " + get_data_type (t ) + "arg_" + str ( i ) + ";\n "
751- parse_args += "&arg_ " + str ( i )
774+ pass_args += i
775+ variables_decl += " " + get_data_type (t ) + i + ";\n "
776+ parse_args += "&" + i
752777
753778 if parse_args != "" :
754779 parse_args = f"""\n // Parse the arguments from Python
@@ -761,12 +786,38 @@ def get_data_type(t):
761786 fill_return_details = ""
762787 if self .return_type != "" :
763788 fill_return_details = f"""\n \n // Call the C function
764- arg_0 = { self .fn_name } ({ pass_args } );
789+ _ { self . fn_name } _return_value = { self .fn_name } ({ pass_args } );
765790
766791 // Build and return the result as a Python object
767- return Py_BuildValue("{ self .return_type_format } ", arg_0 );"""
792+ return Py_BuildValue("{ self .return_type_format } ", _ { self . fn_name } _return_value );"""
768793 else :
769- fill_return_details = f"""{ self .fn_name } ({ pass_args } );
794+ if self .array_as_return_type :
795+ fill_return_details = f"""\n
796+ _{ self .fn_name } _return_value->data = malloc({ self .array_as_return_type [1 ][3 ]
797+ } * sizeof({ self .array_as_return_type [1 ][2 ][:- 2 ]} ));
798+ _{ self .fn_name } _return_value->n_dims = 1;
799+ _{ self .fn_name } _return_value->dims[0].lower_bound = 0;
800+ _{ self .fn_name } _return_value->dims[0].length = {
801+ self .array_as_return_type [1 ][3 ]} ;
802+ _{ self .fn_name } _return_value->is_allocated = false;
803+
804+ // Call the C function
805+ { self .fn_name } ({ pass_args } , &_{ self .fn_name } _return_value[0]);
806+
807+ // Build and return the result as a Python object
808+ {{
809+ npy_intp dims[] = {{{ self .array_as_return_type [1 ][3 ]} }};
810+ PyObject* numpy_array = PyArray_SimpleNewFromData(1, dims, {
811+ get_typenum (self .array_as_return_type [1 ][2 ][:- 2 ])} ,
812+ _{ self .fn_name } _return_value->data);
813+ if (numpy_array == NULL) {{
814+ PyErr_SetString(PyExc_TypeError, "error creating an array");
815+ return NULL;
816+ }}
817+ return numpy_array;
818+ }}"""
819+ else :
820+ fill_return_details = f"""{ self .fn_name } ({ pass_args } );
770821 Py_RETURN_NONE;"""
771822
772823 # ----------------------------------------------------------------------
0 commit comments