diff --git a/.gitignore b/.gitignore index d9a4892a9b..44c7866bae 100644 --- a/.gitignore +++ b/.gitignore @@ -89,6 +89,9 @@ inst/bin/* *_ldd.txt *_lines.dat.txt *__tmp__generated__.c +a.c +a.h +a.py ### https://raw.github.com/github/gitignore/218a941be92679ce67d0484547e3e142b2f5f6f0/Global/macOS.gitignore diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index b3ab956052..deab64d41a 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -534,3 +534,6 @@ RUN(NAME callback_01 LABELS cpython llvm) RUN(NAME intrinsics_01 LABELS cpython llvm) # any COMPILE(NAME import_order_01 LABELS cpython llvm c) # any + +# Jit +RUN(NAME test_jit_01 LABELS cpython) diff --git a/integration_tests/test_jit_01.py b/integration_tests/test_jit_01.py new file mode 100644 index 0000000000..a92bd63e9a --- /dev/null +++ b/integration_tests/test_jit_01.py @@ -0,0 +1,16 @@ +from numpy import array +from lpython import i32, f64, jit + +@jit +def fast_sum(n: i32, x: f64[:]) -> f64: + s: f64 = 0.0 + i: i32 + for i in range(n): + s += x[i] + return s + +def test(): + x: f64[3] = array([1.0, 2.0, 3.0]) + assert fast_sum(3, x) == 6.0 + +test() diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 1d78ee6abd..f0fb38b7d6 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -3500,9 +3500,13 @@ class SymbolTableVisitor : public CommonVisitor { is_inline = true; } else if (name == "static") { is_static = true; + } else if (name == "jit") { + throw SemanticError("`@lpython.jit` decorator must be " + "run from CPython, not compiled using LPython", + dec->base.loc); } else { throw SemanticError("Decorator: " + name + " is not supported", - x.base.base.loc); + dec->base.loc); } } else if (AST::is_a(*dec)) { AST::Call_t *call_d = AST::down_cast(dec); diff --git a/src/runtime/lpython/lpython.py b/src/runtime/lpython/lpython.py index 1456b4b4b5..2fb93681a9 100644 --- a/src/runtime/lpython/lpython.py +++ b/src/runtime/lpython/lpython.py @@ -1,9 +1,11 @@ -from inspect import getfullargspec, getcallargs, isclass +from inspect import getfullargspec, getcallargs, isclass, getsource import os import ctypes import platform from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass from goto import with_goto +from numpy import get_include +from distutils.sysconfig import get_python_inc # TODO: this does not seem to restrict other imports __slots__ = ["i8", "i16", "i32", "i64", "f32", "f64", "c32", "c64", "CPtr", @@ -514,3 +516,220 @@ def ccallable(f): def ccallback(f): return f + +class jit: + def __init__(self, function): + def get_rtlib_dir(): + current_dir = os.path.dirname(os.path.abspath(__file__)) + return os.path.join(current_dir, "..") + + def get_type_info(arg): + # return_type -> (`type_format`, `variable type`, `array struct name`) + # See: https://docs.python.org/3/c-api/arg.html for more info on type_format + if arg == f64: + return ('d', "double", 'r64') + elif arg == f32: + return ('f', "float", 'r32') + elif arg == i64: + return ('l', "long int", 'i64') + elif arg == i32: + return ('i', "int", 'i32') + elif arg == bool: + return ('p', "bool", '') + elif isinstance(arg, Array): + t = get_type_info(arg._type) + if t[2] == '': + raise NotImplementedError("Type %r not implemented" % arg) + return ('O', ["PyArrayObject *", "struct "+t[2]+" *", t[1]+" *"], '') + else: + raise NotImplementedError("Type %r not implemented" % arg) + + def get_data_type(t): + if isinstance(t, list): + return t[0] + else: + return t + " " + + self.fn_name = function.__name__ + # Get the source code of the function + source_code = getsource(function) + source_code = source_code[source_code.find('\n'):] + + # TODO: Create a filename based on the function name + # filename = function.__name__ + ".py" + + # Open the file for writing + with open("a.py", "w") as file: + # Write the Python source code to the file + file.write("@ccallable") + file.write(source_code) + # ---------------------------------------------------------------------- + types = function.__annotations__ + self.arg_type_formats = "" + self.return_type = "" + self.return_type_format = "" + self.arg_types = {} + counter = 1 + for t in types.keys(): + if t == "return": + type = get_type_info(types[t]) + self.return_type_format = type[0] + self.return_type = type[1] + else: + type = get_type_info(types[t]) + self.arg_type_formats += type[0] + self.arg_types[counter] = type[1] + counter += 1 + # ---------------------------------------------------------------------- + # `arg_0`: used as the return variables + # arguments are declared as `arg_1`, `arg_2`, ... + variables_decl = "" + if self.return_type != "": + variables_decl = "// Declare return variables and arguments\n" + variables_decl += " " + get_data_type(self.return_type) + "arg_" \ + + str(0) + ";\n" + # ---------------------------------------------------------------------- + # `PyArray_AsCArray` is used to convert NumPy Arrays to C Arrays + # `fill_array_details` contains arrays operations to be + # performed on the arguments + # `parse_args` are used to capture the args from CPython + # `pass_args` are the args that are passed to the shared library function + fill_array_details = "" + parse_args = "" + pass_args = "" + numpy_init = "" + for i, t in self.arg_types.items(): + if i > 1: + parse_args += ", " + pass_args += ", " + if isinstance(t, list): + if numpy_init == "": + numpy_init = "// Initialize NumPy\n import_array();\n\n " + fill_array_details += f"""\n + // fill array details for args[{i-1}] + if (PyArray_NDIM(arg_{i}) != 1) {{ + PyErr_SetString(PyExc_TypeError, + "Only 1 dimension is implemented for now."); + return NULL; + }} + + {t[1]}s_array_{i} = malloc(sizeof(struct r64)); + {{ + {t[2]}array; + // Create C arrays from numpy objects: + PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE(arg_{i})); + npy_intp dims[1]; + if (PyArray_AsCArray((PyObject **)&arg_{i}, (void *)&array, dims, 1, descr) < 0) {{ + PyErr_SetString(PyExc_TypeError, "error converting to c array"); + return NULL; + }} + + s_array_{i}->data = array; + s_array_{i}->n_dims = 1; + s_array_{i}->dims[0].lower_bound = 0; + s_array_{i}->dims[0].length = dims[0]; + s_array_{i}->is_allocated = false; + }}""" + pass_args += "s_array_" + str(i) + else: + pass_args += "arg_" + str(i) + variables_decl += " " + get_data_type(t) + "arg_" + str(i) + ";\n" + parse_args += "&arg_" + str(i) + + if parse_args != "": + parse_args = f"""\n // Parse the arguments from Python + if (!PyArg_ParseTuple(args, "{self.arg_type_formats}", {parse_args})) {{ + return NULL; + }}""" + + # ---------------------------------------------------------------------- + # Handle the return variable if any; otherwise, return None + fill_return_details = "" + if self.return_type != "": + fill_return_details = f"""\n\n // Call the C function + arg_0 = {self.fn_name}({pass_args}); + + // Build and return the result as a Python object + return Py_BuildValue("{self.return_type_format}", arg_0);""" + else: + fill_return_details = f"""{self.fn_name}({pass_args}); + Py_RETURN_NONE;""" + + # ---------------------------------------------------------------------- + # Python wrapper for the Shared library + template = f"""// Python headers +#include + +// NumPy C/API headers +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION // remove warnings +#include + +// LPython generated C code +#include "a.h" + +// Define the Python module and method mappings +static PyObject* define_module(PyObject* self, PyObject* args) {{ + {numpy_init}{variables_decl}{parse_args}\ +{fill_array_details}{fill_return_details} +}} + +// Define the module's method table +static PyMethodDef module_methods[] = {{ + {{"{self.fn_name}", define_module, METH_VARARGS, + "Handle arguments & return variable and call the function"}}, + {{NULL, NULL, 0, NULL}} +}}; + +// Define the module initialization function +static struct PyModuleDef module_def = {{ + PyModuleDef_HEAD_INIT, + "lpython_jit_module", + "Shared library to use LPython generated functions", + -1, + module_methods +}}; + +PyMODINIT_FUNC PyInit_lpython_jit_module(void) {{ + PyObject* module; + + // Create the module object + module = PyModule_Create(&module_def); + if (!module) {{ + return NULL; + }} + + return module; +}} +""" + # ---------------------------------------------------------------------- + # Write the C source code to the file + with open("a.c", "w") as file: + file.write(template) + + # ---------------------------------------------------------------------- + # Generate the Shared library + # TODO: Use LLVM instead of C backend + r = os.system("lpython --show-c --disable-main a.py > a.h") + assert r == 0, "Failed to create C file" + gcc_flags = "" + if platform.system() == "Linux": + gcc_flags = " -shared -fPIC " + elif platform.system() == "Darwin": + gcc_flags = " -bundle -flat_namespace -undefined suppress " + else: + raise NotImplementedError("Platform not implemented") + python_path = "-I" + get_python_inc() + " " + numpy_path = "-I" + get_include() + rt_path_01 = "-I" + get_rtlib_dir() + "/../libasr/runtime " + rt_path_02 = "-L" + get_rtlib_dir() + " -Wl,-rpath " \ + + get_rtlib_dir() + " -llpython_runtime " + python_lib = "-L" "$CONDA_PREFIX/lib/ -lpython3.10 -lm" + r = os.system("gcc -g" + gcc_flags + python_path + numpy_path + + " a.c -o lpython_jit_module.so " + rt_path_01 + rt_path_02 + python_lib) + assert r == 0, "Failed to create the shared library" + + def __call__(self, *args, **kwargs): + import sys; sys.path.append('.') + # import the symbol from the shared library + function = getattr(__import__("lpython_jit_module"), self.fn_name) + return function(*args, **kwargs)