Skip to content

Initial PythonCallable implementation #1984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions integration_tests/lpython_decorator_01.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from numpy import array
from lpython import i32, f64, lpython
from lpython import i32, f64, temp_lpython

@lpython
@temp_lpython
def fast_sum(n: i32, x: f64[:]) -> f64:
s: f64 = 0.0
i: i32
Expand Down
168 changes: 165 additions & 3 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,33 @@ R"(#include <stdio.h>
return code;
}

std::string get_type_format(ASR::ttype_t *type) {
// See: https://docs.python.org/3/c-api/arg.html for more info on `type format`
switch (type->type) {
case ASR::ttypeType::Integer: {
int a_kind = ASRUtils::extract_kind_from_ttype_t(type);
if (a_kind == 4) {
return "i";
} else {
return "l";
}
} case ASR::ttypeType::Real : {
int a_kind = ASRUtils::extract_kind_from_ttype_t(type);
if (a_kind == 4) {
return "f";
} else {
return "d";
}
} case ASR::ttypeType::Logical : {
return "p";
} case ASR::ttypeType::Array : {
return "O";
} default: {
throw CodeGenError("CPython type format not supported yet");
}
}
}

void visit_Function(const ASR::Function_t &x) {
current_body = "";
SymbolTable* current_scope_copy = current_scope;
Expand Down Expand Up @@ -714,13 +741,148 @@ R"(#include <stdio.h>
}
sub += "\n";
src = sub;
if (f_type->m_abi == ASR::abiType::BindC
&& f_type->m_deftype == ASR::deftypeType::Implementation) {
if (x.m_module_file) {
if (f_type->m_deftype == ASR::deftypeType::Implementation) {
if (f_type->m_abi == ASR::abiType::BindC && x.m_module_file) {
std::string header_name = std::string(x.m_module_file);
user_headers.insert(header_name);
emit_headers[header_name]+= "\n" + src;
src = "";
} else if (f_type->m_abi == ASR::abiType::BindPython) {
headers.insert("Python.h");
std::string variables_decl = "";
std::string fill_parse_args_details = "";
std::string type_format = "";
std::string fn_args = "";
std::string fill_array_details = "";
std::string numpy_init = "";

for (size_t i = 0; i < x.n_args; i++) {
ASR::Variable_t *arg = ASRUtils::EXPR2VAR(x.m_args[i]);
std::string arg_name = arg->m_name;
fill_parse_args_details += "&" + arg_name;
type_format += get_type_format(arg->m_type);

if (ASR::is_a<ASR::Array_t>(*arg->m_type)) {
if (numpy_init.size() == 0) {
numpy_init = R"(
// Initialize NumPy
import_array();
)";
headers.insert("numpy/ndarrayobject.h");
user_defines.insert("NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION");
}
fn_args += "s_array_" + arg_name;
variables_decl += " PyArrayObject *" + arg_name + ";\n";
std::string c_array_type = self().convert_variable_decl(*arg);
c_array_type = c_array_type.substr(0,
c_array_type.size() - arg_name.size() - 2);
fill_array_details += "\n // fill array details for " + arg_name
+ "\n if (PyArray_NDIM(" + arg_name + R"() != 1) {
PyErr_SetString(PyExc_TypeError, "An error occurred in the `lpython` decorator: "
"Only 1 dimension array is supported for now.");
return NULL;
}

)" + c_array_type + R"( *s_array_)" + arg_name + R"( = malloc(sizeof()" + c_array_type + R"());
{
)" + CUtils::get_c_type_from_ttype_t(arg->m_type) + R"( *array;
// Create C arrays from numpy objects:
PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE()" + arg_name + R"());
npy_intp dims[1];
if (PyArray_AsCArray((PyObject **)&)" + arg_name + R"(, (void *)&array, dims, 1, descr) < 0) {
PyErr_SetString(PyExc_TypeError, "An error occurred in the `lpython` decorator: "
"Failed to create a C array");
return NULL;
}

s_array_)" + arg_name + R"(->data = array;
s_array_)" + arg_name + R"(->n_dims = 1;
s_array_)" + arg_name + R"(->dims[0].lower_bound = 0;
s_array_)" + arg_name + R"(->dims[0].length = dims[0];
s_array_)" + arg_name + R"(->is_allocated = false;
}
)";
} else {
fn_args += arg_name;
variables_decl += " " + self().convert_variable_decl(*arg)
+ ";\n";
}
if (i < x.n_args - 1) {
fill_parse_args_details += ", ";
fn_args += ", ";
}
}

if (fill_parse_args_details.size() > 0) {
fill_parse_args_details = R"(
// Parse the arguments from Python
if (!PyArg_ParseTuple(args, ")" + type_format + R"(", )" + fill_parse_args_details + R"()) {
PyErr_SetString(PyExc_TypeError, "An error occurred in the `lpython` decorator: "
"Failed to parse or receive arguments from Python");
return NULL;
}
)";
}

std::string fn_name = x.m_name;
std::string fill_return_details;
if (variables_decl.size() > 0) {
variables_decl.insert(0, "\n "
"// Declare arguments and return variable\n");
}
if(x.m_return_var) {
ASR::Variable_t *return_var = ASRUtils::EXPR2VAR(x.m_return_var);
variables_decl += " " + self().convert_variable_decl(*return_var)
+ ";\n";
fill_return_details = R"(
// Call the C function
_lpython_return_variable = )" + fn_name + "(" + fn_args + ");\n" + R"(
// Build and return the result as a Python object
return Py_BuildValue(")" + get_type_format(return_var->m_type)
+ R"(", _lpython_return_variable);)";
} else {
fill_return_details = R"(
// Call the C function
)" + fn_name + "(" + fn_args + ");\n" + R"(
// Return None
Py_RETURN_NONE;)";
}
src = sub;
src += R"(// Define the Python module and method mappings
static PyObject* )" + fn_name + R"(_define_module(PyObject* self, PyObject* args) {)"
+ numpy_init + variables_decl + fill_parse_args_details
+ fill_array_details + fill_return_details + R"(
}

// Define the module's method table
static PyMethodDef )" + fn_name + R"(_module_methods[] = {
{")" + fn_name + R"(", )" + fn_name + R"(_define_module, METH_VARARGS,
"Handle arguments & return variable and call the function"},
{NULL, NULL, 0, NULL}
};

// Define the module initialization function
static struct PyModuleDef )" + fn_name + R"(_module_def = {
PyModuleDef_HEAD_INIT,
"lpython_module_)" + fn_name + R"(",
"Shared library to use LPython generated functions",
-1,
)" + fn_name + R"(_module_methods
};

PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
PyObject* module;

// Create the module object
module = PyModule_Create(&)" + fn_name + R"(_module_def);
if (!module) {
return NULL;
}

return module;
}

)";
}
}
current_scope = current_scope_copy;
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/pass/unused_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class CollectUnusedFunctionsVisitor :

void visit_Function(const ASR::Function_t &x) {
uint64_t h = get_hash((ASR::asr_t*)&x);
if (ASRUtils::get_FunctionType(x)->m_abi != ASR::abiType::BindC) {
if (ASRUtils::get_FunctionType(x)->m_abi != ASR::abiType::BindC
&& ASRUtils::get_FunctionType(x)->m_abi != ASR::abiType::BindPython) {
fn_declarations[h] = x.m_name;
}

Expand Down
4 changes: 2 additions & 2 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3950,9 +3950,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
current_procedure_interface = true;
} else if (name == "ccallback" || name == "ccallable") {
current_procedure_abi_type = ASR::abiType::BindC;
} else if (name == "pythoncall") {
} else if (name == "pythoncall" || name == "pythoncallable") {
current_procedure_abi_type = ASR::abiType::BindPython;
current_procedure_interface = true;
current_procedure_interface = (name == "pythoncall");
} else if (name == "overload") {
overload = true;
} else if (name == "interface") {
Expand Down
62 changes: 61 additions & 1 deletion src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ def get_typenum(t):

# ----------------------------------------------------------------------
# Python wrapper for the Shared library
template = f"""// Python headers
template = f"""// Python C/API headers
#include <Python.h>

// NumPy C/API headers
Expand Down Expand Up @@ -923,3 +923,63 @@ def __call__(self, *args, **kwargs):
function = getattr(__import__("lpython_module_" + self.fn_name),
self.fn_name)
return function(*args, **kwargs)


class temp_lpython:
def __init__(self, function):
def get_rtlib_dir():
current_dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(current_dir, "..")

self.fn_name = function.__name__
# Get the source code of the function
source_code = getsource(function)
source_code = source_code[source_code.find('\n'):]

dir_name = "./lpython_decorator_" + self.fn_name
if not os.path.exists(dir_name):
os.mkdir(dir_name)
filename = dir_name + "/" + self.fn_name

# Open the file for writing
with open(filename + ".py", "w") as file:
# Write the Python source code to the file
file.write("@pythoncallable")
file.write(source_code)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly!

# ----------------------------------------------------------------------

r = os.system("lpython --show-c --disable-main "
+ filename + ".py > " + filename + ".c")
assert r == 0, "Failed to create C file"

gcc_flags = ""
if platform.system() == "Linux":
gcc_flags = " -shared -fPIC "
elif platform.system() == "Darwin":
gcc_flags = " -bundle -flat_namespace -undefined suppress "
else:
raise NotImplementedError("Platform not implemented")

from numpy import get_include
from distutils.sysconfig import get_python_inc, get_python_lib, \
get_python_version
python_path = "-I" + get_python_inc() + " "
numpy_path = "-I" + get_include() + " "
rt_path_01 = "-I" + get_rtlib_dir() + "/../libasr/runtime "
rt_path_02 = "-L" + get_rtlib_dir() + " -Wl,-rpath " \
+ get_rtlib_dir() + " -llpython_runtime "
python_lib = "-L" + get_python_lib() + "/../.. -lpython" + \
get_python_version() + " -lm"


r = os.system("gcc -g" + gcc_flags + python_path + numpy_path +
filename + ".c -o lpython_module_" + self.fn_name + ".so " +
rt_path_01 + rt_path_02 + python_lib)
assert r == 0, "Failed to create the shared library"

def __call__(self, *args, **kwargs):
import sys; sys.path.append('.')
# import the symbol from the shared library
function = getattr(__import__("lpython_module_" + self.fn_name),
self.fn_name)
return function(*args, **kwargs)