@@ -662,6 +662,18 @@ def get_data_type(t):
662
662
else :
663
663
return t + " "
664
664
665
+ def get_typenum (t ):
666
+ if t == "int" :
667
+ return "NPY_INT"
668
+ elif t == "long int" :
669
+ return "NPY_LONG"
670
+ elif t == "float" :
671
+ return "NPY_FLOAT"
672
+ elif t == "double" :
673
+ return "NPY_DOUBLE"
674
+ else :
675
+ raise NotImplementedError ("Type %s not implemented" % t )
676
+
665
677
self .fn_name = function .__name__
666
678
# Get the source code of the function
667
679
source_code = getsource (function )
@@ -789,12 +801,17 @@ def get_data_type(t):
789
801
{ self .fn_name } ({ pass_args } , &_{ self .fn_name } _return_value[0]);
790
802
791
803
// Build and return the result as a Python object
792
- PyObject* list_obj = PyList_New({ self .array_as_return_type [1 ][3 ]} );
793
- for (int i = 0; i < { self .array_as_return_type [1 ][3 ]} ; i++) {{
794
- PyObject* element = PyFloat_FromDouble(_{ self .fn_name } _return_value->data[i]);
795
- PyList_SetItem(list_obj, i, element);
796
- }}
797
- return list_obj;"""
804
+ {{
805
+ npy_intp dims[] = {{{ self .array_as_return_type [1 ][3 ]} }};
806
+ PyObject* numpy_array = PyArray_SimpleNewFromData(1, dims, {
807
+ get_typenum (self .array_as_return_type [1 ][2 ][:- 2 ])} ,
808
+ _{ self .fn_name } _return_value->data);
809
+ if (numpy_array == NULL) {{
810
+ PyErr_SetString(PyExc_TypeError, "error creating an array");
811
+ return NULL;
812
+ }}
813
+ return numpy_array;
814
+ }}"""
798
815
else :
799
816
fill_return_details = f"""{ self .fn_name } ({ pass_args } );
800
817
Py_RETURN_NONE;"""
@@ -884,8 +901,4 @@ def __call__(self, *args, **kwargs):
884
901
# import the symbol from the shared library
885
902
function = getattr (__import__ ("lpython_module_" + self .fn_name ),
886
903
self .fn_name )
887
- if self .array_as_return_type :
888
- from numpy import array
889
- return array (function (* args , ** kwargs ))
890
- else :
891
- return function (* args , ** kwargs )
904
+ return function (* args , ** kwargs )
0 commit comments