Skip to content

@jit interface for LPython #1708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ inst/bin/*
*_ldd.txt
*_lines.dat.txt
*__tmp__generated__.c
a.c
a.h
a.py

### https://raw.github.com/github/gitignore/218a941be92679ce67d0484547e3e142b2f5f6f0/Global/macOS.gitignore

Expand Down
3 changes: 3 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,6 @@ RUN(NAME callback_01 LABELS cpython llvm)
RUN(NAME intrinsics_01 LABELS cpython llvm) # any

COMPILE(NAME import_order_01 LABELS cpython llvm c) # any

# Jit
RUN(NAME test_jit_01 LABELS cpython)
16 changes: 16 additions & 0 deletions integration_tests/test_jit_01.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from numpy import array
from lpython import i32, f64, jit

@jit
def fast_sum(n: i32, x: f64[:]) -> f64:
s: f64 = 0.0
i: i32
for i in range(n):
s += x[i]
return s

def test():
x: f64[3] = array([1.0, 2.0, 3.0])
assert fast_sum(3, x) == 6.0

test()
6 changes: 5 additions & 1 deletion src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3500,9 +3500,13 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
is_inline = true;
} else if (name == "static") {
is_static = true;
} else if (name == "jit") {
throw SemanticError("`@lpython.jit` decorator must be "
"run from CPython, not compiled using LPython",
dec->base.loc);
} else {
throw SemanticError("Decorator: " + name + " is not supported",
x.base.base.loc);
dec->base.loc);
}
} else if (AST::is_a<AST::Call_t>(*dec)) {
AST::Call_t *call_d = AST::down_cast<AST::Call_t>(dec);
Expand Down
221 changes: 220 additions & 1 deletion src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from inspect import getfullargspec, getcallargs, isclass
from inspect import getfullargspec, getcallargs, isclass, getsource
import os
import ctypes
import platform
from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass
from goto import with_goto
from numpy import get_include
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't want to include numpy here, since it would be imported with every "import lpython" in every file and things will be slow. Python's imports are slow, and we want "import lpython" to be immediate. For now I would "import numpy" when the decorator is processed, not here. This can be done in subsequent PR.

from distutils.sysconfig import get_python_inc

# TODO: this does not seem to restrict other imports
__slots__ = ["i8", "i16", "i32", "i64", "f32", "f64", "c32", "c64", "CPtr",
Expand Down Expand Up @@ -514,3 +516,220 @@ def ccallable(f):

def ccallback(f):
return f

class jit:
def __init__(self, function):
def get_rtlib_dir():
current_dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(current_dir, "..")

def get_type_info(arg):
# return_type -> (`type_format`, `variable type`, `array struct name`)
# See: https://docs.python.org/3/c-api/arg.html for more info on type_format
if arg == f64:
return ('d', "double", 'r64')
elif arg == f32:
return ('f', "float", 'r32')
elif arg == i64:
return ('l', "long int", 'i64')
elif arg == i32:
return ('i', "int", 'i32')
elif arg == bool:
return ('p', "bool", '')
elif isinstance(arg, Array):
t = get_type_info(arg._type)
if t[2] == '':
raise NotImplementedError("Type %r not implemented" % arg)
return ('O', ["PyArrayObject *", "struct "+t[2]+" *", t[1]+" *"], '')
else:
raise NotImplementedError("Type %r not implemented" % arg)

def get_data_type(t):
if isinstance(t, list):
return t[0]
else:
return t + " "

self.fn_name = function.__name__
# Get the source code of the function
source_code = getsource(function)
source_code = source_code[source_code.find('\n'):]

# TODO: Create a filename based on the function name
# filename = function.__name__ + ".py"

# Open the file for writing
with open("a.py", "w") as file:
# Write the Python source code to the file
file.write("@ccallable")
file.write(source_code)
# ----------------------------------------------------------------------
types = function.__annotations__
self.arg_type_formats = ""
self.return_type = ""
self.return_type_format = ""
self.arg_types = {}
counter = 1
for t in types.keys():
if t == "return":
type = get_type_info(types[t])
self.return_type_format = type[0]
self.return_type = type[1]
else:
type = get_type_info(types[t])
self.arg_type_formats += type[0]
self.arg_types[counter] = type[1]
counter += 1
# ----------------------------------------------------------------------
# `arg_0`: used as the return variables
# arguments are declared as `arg_1`, `arg_2`, ...
variables_decl = ""
if self.return_type != "":
variables_decl = "// Declare return variables and arguments\n"
variables_decl += " " + get_data_type(self.return_type) + "arg_" \
+ str(0) + ";\n"
# ----------------------------------------------------------------------
# `PyArray_AsCArray` is used to convert NumPy Arrays to C Arrays
# `fill_array_details` contains arrays operations to be
# performed on the arguments
# `parse_args` are used to capture the args from CPython
# `pass_args` are the args that are passed to the shared library function
fill_array_details = ""
parse_args = ""
pass_args = ""
numpy_init = ""
for i, t in self.arg_types.items():
if i > 1:
parse_args += ", "
pass_args += ", "
if isinstance(t, list):
if numpy_init == "":
numpy_init = "// Initialize NumPy\n import_array();\n\n "
fill_array_details += f"""\n
// fill array details for args[{i-1}]
if (PyArray_NDIM(arg_{i}) != 1) {{
PyErr_SetString(PyExc_TypeError,
"Only 1 dimension is implemented for now.");
return NULL;
}}

{t[1]}s_array_{i} = malloc(sizeof(struct r64));
{{
{t[2]}array;
// Create C arrays from numpy objects:
PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE(arg_{i}));
npy_intp dims[1];
if (PyArray_AsCArray((PyObject **)&arg_{i}, (void *)&array, dims, 1, descr) < 0) {{
PyErr_SetString(PyExc_TypeError, "error converting to c array");
return NULL;
}}

s_array_{i}->data = array;
s_array_{i}->n_dims = 1;
s_array_{i}->dims[0].lower_bound = 0;
s_array_{i}->dims[0].length = dims[0];
s_array_{i}->is_allocated = false;
}}"""
pass_args += "s_array_" + str(i)
else:
pass_args += "arg_" + str(i)
variables_decl += " " + get_data_type(t) + "arg_" + str(i) + ";\n"
parse_args += "&arg_" + str(i)

if parse_args != "":
parse_args = f"""\n // Parse the arguments from Python
if (!PyArg_ParseTuple(args, "{self.arg_type_formats}", {parse_args})) {{
return NULL;
}}"""

# ----------------------------------------------------------------------
# Handle the return variable if any; otherwise, return None
fill_return_details = ""
if self.return_type != "":
fill_return_details = f"""\n\n // Call the C function
arg_0 = {self.fn_name}({pass_args});

// Build and return the result as a Python object
return Py_BuildValue("{self.return_type_format}", arg_0);"""
else:
fill_return_details = f"""{self.fn_name}({pass_args});
Py_RETURN_NONE;"""

# ----------------------------------------------------------------------
# Python wrapper for the Shared library
template = f"""// Python headers
#include <Python.h>

// NumPy C/API headers
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION // remove warnings
#include <numpy/ndarrayobject.h>

// LPython generated C code
#include "a.h"

// Define the Python module and method mappings
static PyObject* define_module(PyObject* self, PyObject* args) {{
{numpy_init}{variables_decl}{parse_args}\
{fill_array_details}{fill_return_details}
}}

// Define the module's method table
static PyMethodDef module_methods[] = {{
{{"{self.fn_name}", define_module, METH_VARARGS,
"Handle arguments & return variable and call the function"}},
{{NULL, NULL, 0, NULL}}
}};

// Define the module initialization function
static struct PyModuleDef module_def = {{
PyModuleDef_HEAD_INIT,
"lpython_jit_module",
"Shared library to use LPython generated functions",
-1,
module_methods
}};

PyMODINIT_FUNC PyInit_lpython_jit_module(void) {{
PyObject* module;

// Create the module object
module = PyModule_Create(&module_def);
if (!module) {{
return NULL;
}}

return module;
}}
"""
# ----------------------------------------------------------------------
# Write the C source code to the file
with open("a.c", "w") as file:
file.write(template)

# ----------------------------------------------------------------------
# Generate the Shared library
# TODO: Use LLVM instead of C backend
r = os.system("lpython --show-c --disable-main a.py > a.h")
assert r == 0, "Failed to create C file"
gcc_flags = ""
if platform.system() == "Linux":
gcc_flags = " -shared -fPIC "
elif platform.system() == "Darwin":
gcc_flags = " -bundle -flat_namespace -undefined suppress "
else:
raise NotImplementedError("Platform not implemented")
python_path = "-I" + get_python_inc() + " "
numpy_path = "-I" + get_include()
rt_path_01 = "-I" + get_rtlib_dir() + "/../libasr/runtime "
rt_path_02 = "-L" + get_rtlib_dir() + " -Wl,-rpath " \
+ get_rtlib_dir() + " -llpython_runtime "
python_lib = "-L" "$CONDA_PREFIX/lib/ -lpython3.10 -lm"
r = os.system("gcc -g" + gcc_flags + python_path + numpy_path +
" a.c -o lpython_jit_module.so " + rt_path_01 + rt_path_02 + python_lib)
assert r == 0, "Failed to create the shared library"

def __call__(self, *args, **kwargs):
import sys; sys.path.append('.')
# import the symbol from the shared library
function = getattr(__import__("lpython_jit_module"), self.fn_name)
return function(*args, **kwargs)