@@ -666,6 +666,18 @@ def get_data_type(t):
666
666
else :
667
667
return t + " "
668
668
669
+ def get_typenum (t ):
670
+ if t == "int" :
671
+ return "NPY_INT"
672
+ elif t == "long int" :
673
+ return "NPY_LONG"
674
+ elif t == "float" :
675
+ return "NPY_FLOAT"
676
+ elif t == "double" :
677
+ return "NPY_DOUBLE"
678
+ else :
679
+ raise NotImplementedError ("Type %s not implemented" % t )
680
+
669
681
self .fn_name = function .__name__
670
682
# Get the source code of the function
671
683
source_code = getsource (function )
@@ -793,12 +805,17 @@ def get_data_type(t):
793
805
{ self .fn_name } ({ pass_args } , &_{ self .fn_name } _return_value[0]);
794
806
795
807
// Build and return the result as a Python object
796
- PyObject* list_obj = PyList_New({ self .array_as_return_type [1 ][3 ]} );
797
- for (int i = 0; i < { self .array_as_return_type [1 ][3 ]} ; i++) {{
798
- PyObject* element = PyFloat_FromDouble(_{ self .fn_name } _return_value->data[i]);
799
- PyList_SetItem(list_obj, i, element);
800
- }}
801
- return list_obj;"""
808
+ {{
809
+ npy_intp dims[] = {{{ self .array_as_return_type [1 ][3 ]} }};
810
+ PyObject* numpy_array = PyArray_SimpleNewFromData(1, dims, {
811
+ get_typenum (self .array_as_return_type [1 ][2 ][:- 2 ])} ,
812
+ _{ self .fn_name } _return_value->data);
813
+ if (numpy_array == NULL) {{
814
+ PyErr_SetString(PyExc_TypeError, "error creating an array");
815
+ return NULL;
816
+ }}
817
+ return numpy_array;
818
+ }}"""
802
819
else :
803
820
fill_return_details = f"""{ self .fn_name } ({ pass_args } );
804
821
Py_RETURN_NONE;"""
@@ -890,8 +907,4 @@ def __call__(self, *args, **kwargs):
890
907
# import the symbol from the shared library
891
908
function = getattr (__import__ ("lpython_module_" + self .fn_name ),
892
909
self .fn_name )
893
- if self .array_as_return_type :
894
- from numpy import array
895
- return array (function (* args , ** kwargs ))
896
- else :
897
- return function (* args , ** kwargs )
910
+ return function (* args , ** kwargs )
0 commit comments