Skip to content

Commit 67b71d6

Browse files
Create an array using an API
1 parent a9c5ee0 commit 67b71d6

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

src/runtime/lpython/lpython.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,18 @@ def get_data_type(t):
662662
else:
663663
return t + " "
664664

665+
def get_typenum(t):
666+
if t == "int":
667+
return "NPY_INT"
668+
elif t == "long int":
669+
return "NPY_LONG"
670+
elif t == "float":
671+
return "NPY_FLOAT"
672+
elif t == "double":
673+
return "NPY_DOUBLE"
674+
else:
675+
raise NotImplementedError("Type %s not implemented" % t)
676+
665677
self.fn_name = function.__name__
666678
# Get the source code of the function
667679
source_code = getsource(function)
@@ -789,12 +801,17 @@ def get_data_type(t):
789801
{self.fn_name}({pass_args}, &_{self.fn_name}_return_value[0]);
790802
791803
// Build and return the result as a Python object
792-
PyObject* list_obj = PyList_New({self.array_as_return_type[1][3]});
793-
for (int i = 0; i < {self.array_as_return_type[1][3]}; i++) {{
794-
PyObject* element = PyFloat_FromDouble(_{self.fn_name}_return_value->data[i]);
795-
PyList_SetItem(list_obj, i, element);
796-
}}
797-
return list_obj;"""
804+
{{
805+
npy_intp dims[] = {{{self.array_as_return_type[1][3]}}};
806+
PyObject* numpy_array = PyArray_SimpleNewFromData(1, dims, {
807+
get_typenum(self.array_as_return_type[1][2][:-2])},
808+
_{self.fn_name}_return_value->data);
809+
if (numpy_array == NULL) {{
810+
PyErr_SetString(PyExc_TypeError, "error creating an array");
811+
return NULL;
812+
}}
813+
return numpy_array;
814+
}}"""
798815
else:
799816
fill_return_details = f"""{self.fn_name}({pass_args});
800817
Py_RETURN_NONE;"""
@@ -884,8 +901,4 @@ def __call__(self, *args, **kwargs):
884901
# import the symbol from the shared library
885902
function = getattr(__import__("lpython_module_" + self.fn_name),
886903
self.fn_name)
887-
if self.array_as_return_type:
888-
from numpy import array
889-
return array(function(*args, **kwargs))
890-
else:
891-
return function(*args, **kwargs)
904+
return function(*args, **kwargs)

0 commit comments

Comments
 (0)