Skip to content

Commit 9d965a8

Browse files
authored
Merge pull request #1708 from harshsingh-24/jit
`@jit` interface for LPython
2 parents 9028f68 + a5eda0c commit 9d965a8

File tree

5 files changed

+247
-2
lines changed

5 files changed

+247
-2
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ inst/bin/*
8989
*_ldd.txt
9090
*_lines.dat.txt
9191
*__tmp__generated__.c
92+
a.c
93+
a.h
94+
a.py
9295

9396
### https://raw.github.com/github/gitignore/218a941be92679ce67d0484547e3e142b2f5f6f0/Global/macOS.gitignore
9497

integration_tests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,3 +535,6 @@ RUN(NAME callback_01 LABELS cpython llvm)
535535
RUN(NAME intrinsics_01 LABELS cpython llvm) # any
536536

537537
COMPILE(NAME import_order_01 LABELS cpython llvm c) # any
538+
539+
# Jit
540+
RUN(NAME test_jit_01 LABELS cpython)

integration_tests/test_jit_01.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from numpy import array
2+
from lpython import i32, f64, jit
3+
4+
@jit
5+
def fast_sum(n: i32, x: f64[:]) -> f64:
6+
s: f64 = 0.0
7+
i: i32
8+
for i in range(n):
9+
s += x[i]
10+
return s
11+
12+
def test():
13+
x: f64[3] = array([1.0, 2.0, 3.0])
14+
assert fast_sum(3, x) == 6.0
15+
16+
test()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3574,9 +3574,13 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
35743574
is_inline = true;
35753575
} else if (name == "static") {
35763576
is_static = true;
3577+
} else if (name == "jit") {
3578+
throw SemanticError("`@lpython.jit` decorator must be "
3579+
"run from CPython, not compiled using LPython",
3580+
dec->base.loc);
35773581
} else {
35783582
throw SemanticError("Decorator: " + name + " is not supported",
3579-
x.base.base.loc);
3583+
dec->base.loc);
35803584
}
35813585
} else if (AST::is_a<AST::Call_t>(*dec)) {
35823586
AST::Call_t *call_d = AST::down_cast<AST::Call_t>(dec);

src/runtime/lpython/lpython.py

Lines changed: 220 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from inspect import getfullargspec, getcallargs, isclass
1+
from inspect import getfullargspec, getcallargs, isclass, getsource
22
import os
33
import ctypes
44
import platform
55
from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass
66
from goto import with_goto
7+
from numpy import get_include
8+
from distutils.sysconfig import get_python_inc
79

810
# TODO: this does not seem to restrict other imports
911
__slots__ = ["i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "c32", "c64", "CPtr",
@@ -522,3 +524,220 @@ def ccallable(f):
522524

523525
def ccallback(f):
524526
return f
527+
528+
class jit:
529+
def __init__(self, function):
530+
def get_rtlib_dir():
531+
current_dir = os.path.dirname(os.path.abspath(__file__))
532+
return os.path.join(current_dir, "..")
533+
534+
def get_type_info(arg):
535+
# return_type -> (`type_format`, `variable type`, `array struct name`)
536+
# See: https://docs.python.org/3/c-api/arg.html for more info on type_format
537+
if arg == f64:
538+
return ('d', "double", 'r64')
539+
elif arg == f32:
540+
return ('f', "float", 'r32')
541+
elif arg == i64:
542+
return ('l', "long int", 'i64')
543+
elif arg == i32:
544+
return ('i', "int", 'i32')
545+
elif arg == bool:
546+
return ('p', "bool", '')
547+
elif isinstance(arg, Array):
548+
t = get_type_info(arg._type)
549+
if t[2] == '':
550+
raise NotImplementedError("Type %r not implemented" % arg)
551+
return ('O', ["PyArrayObject *", "struct "+t[2]+" *", t[1]+" *"], '')
552+
else:
553+
raise NotImplementedError("Type %r not implemented" % arg)
554+
555+
def get_data_type(t):
556+
if isinstance(t, list):
557+
return t[0]
558+
else:
559+
return t + " "
560+
561+
self.fn_name = function.__name__
562+
# Get the source code of the function
563+
source_code = getsource(function)
564+
source_code = source_code[source_code.find('\n'):]
565+
566+
# TODO: Create a filename based on the function name
567+
# filename = function.__name__ + ".py"
568+
569+
# Open the file for writing
570+
with open("a.py", "w") as file:
571+
# Write the Python source code to the file
572+
file.write("@ccallable")
573+
file.write(source_code)
574+
# ----------------------------------------------------------------------
575+
types = function.__annotations__
576+
self.arg_type_formats = ""
577+
self.return_type = ""
578+
self.return_type_format = ""
579+
self.arg_types = {}
580+
counter = 1
581+
for t in types.keys():
582+
if t == "return":
583+
type = get_type_info(types[t])
584+
self.return_type_format = type[0]
585+
self.return_type = type[1]
586+
else:
587+
type = get_type_info(types[t])
588+
self.arg_type_formats += type[0]
589+
self.arg_types[counter] = type[1]
590+
counter += 1
591+
# ----------------------------------------------------------------------
592+
# `arg_0`: used as the return variables
593+
# arguments are declared as `arg_1`, `arg_2`, ...
594+
variables_decl = ""
595+
if self.return_type != "":
596+
variables_decl = "// Declare return variables and arguments\n"
597+
variables_decl += " " + get_data_type(self.return_type) + "arg_" \
598+
+ str(0) + ";\n"
599+
# ----------------------------------------------------------------------
600+
# `PyArray_AsCArray` is used to convert NumPy Arrays to C Arrays
601+
# `fill_array_details` contains arrays operations to be
602+
# performed on the arguments
603+
# `parse_args` are used to capture the args from CPython
604+
# `pass_args` are the args that are passed to the shared library function
605+
fill_array_details = ""
606+
parse_args = ""
607+
pass_args = ""
608+
numpy_init = ""
609+
for i, t in self.arg_types.items():
610+
if i > 1:
611+
parse_args += ", "
612+
pass_args += ", "
613+
if isinstance(t, list):
614+
if numpy_init == "":
615+
numpy_init = "// Initialize NumPy\n import_array();\n\n "
616+
fill_array_details += f"""\n
617+
// fill array details for args[{i-1}]
618+
if (PyArray_NDIM(arg_{i}) != 1) {{
619+
PyErr_SetString(PyExc_TypeError,
620+
"Only 1 dimension is implemented for now.");
621+
return NULL;
622+
}}
623+
624+
{t[1]}s_array_{i} = malloc(sizeof(struct r64));
625+
{{
626+
{t[2]}array;
627+
// Create C arrays from numpy objects:
628+
PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE(arg_{i}));
629+
npy_intp dims[1];
630+
if (PyArray_AsCArray((PyObject **)&arg_{i}, (void *)&array, dims, 1, descr) < 0) {{
631+
PyErr_SetString(PyExc_TypeError, "error converting to c array");
632+
return NULL;
633+
}}
634+
635+
s_array_{i}->data = array;
636+
s_array_{i}->n_dims = 1;
637+
s_array_{i}->dims[0].lower_bound = 0;
638+
s_array_{i}->dims[0].length = dims[0];
639+
s_array_{i}->is_allocated = false;
640+
}}"""
641+
pass_args += "s_array_" + str(i)
642+
else:
643+
pass_args += "arg_" + str(i)
644+
variables_decl += " " + get_data_type(t) + "arg_" + str(i) + ";\n"
645+
parse_args += "&arg_" + str(i)
646+
647+
if parse_args != "":
648+
parse_args = f"""\n // Parse the arguments from Python
649+
if (!PyArg_ParseTuple(args, "{self.arg_type_formats}", {parse_args})) {{
650+
return NULL;
651+
}}"""
652+
653+
# ----------------------------------------------------------------------
654+
# Handle the return variable if any; otherwise, return None
655+
fill_return_details = ""
656+
if self.return_type != "":
657+
fill_return_details = f"""\n\n // Call the C function
658+
arg_0 = {self.fn_name}({pass_args});
659+
660+
// Build and return the result as a Python object
661+
return Py_BuildValue("{self.return_type_format}", arg_0);"""
662+
else:
663+
fill_return_details = f"""{self.fn_name}({pass_args});
664+
Py_RETURN_NONE;"""
665+
666+
# ----------------------------------------------------------------------
667+
# Python wrapper for the Shared library
668+
template = f"""// Python headers
669+
#include <Python.h>
670+
671+
// NumPy C/API headers
672+
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION // remove warnings
673+
#include <numpy/ndarrayobject.h>
674+
675+
// LPython generated C code
676+
#include "a.h"
677+
678+
// Define the Python module and method mappings
679+
static PyObject* define_module(PyObject* self, PyObject* args) {{
680+
{numpy_init}{variables_decl}{parse_args}\
681+
{fill_array_details}{fill_return_details}
682+
}}
683+
684+
// Define the module's method table
685+
static PyMethodDef module_methods[] = {{
686+
{{"{self.fn_name}", define_module, METH_VARARGS,
687+
"Handle arguments & return variable and call the function"}},
688+
{{NULL, NULL, 0, NULL}}
689+
}};
690+
691+
// Define the module initialization function
692+
static struct PyModuleDef module_def = {{
693+
PyModuleDef_HEAD_INIT,
694+
"lpython_jit_module",
695+
"Shared library to use LPython generated functions",
696+
-1,
697+
module_methods
698+
}};
699+
700+
PyMODINIT_FUNC PyInit_lpython_jit_module(void) {{
701+
PyObject* module;
702+
703+
// Create the module object
704+
module = PyModule_Create(&module_def);
705+
if (!module) {{
706+
return NULL;
707+
}}
708+
709+
return module;
710+
}}
711+
"""
712+
# ----------------------------------------------------------------------
713+
# Write the C source code to the file
714+
with open("a.c", "w") as file:
715+
file.write(template)
716+
717+
# ----------------------------------------------------------------------
718+
# Generate the Shared library
719+
# TODO: Use LLVM instead of C backend
720+
r = os.system("lpython --show-c --disable-main a.py > a.h")
721+
assert r == 0, "Failed to create C file"
722+
gcc_flags = ""
723+
if platform.system() == "Linux":
724+
gcc_flags = " -shared -fPIC "
725+
elif platform.system() == "Darwin":
726+
gcc_flags = " -bundle -flat_namespace -undefined suppress "
727+
else:
728+
raise NotImplementedError("Platform not implemented")
729+
python_path = "-I" + get_python_inc() + " "
730+
numpy_path = "-I" + get_include()
731+
rt_path_01 = "-I" + get_rtlib_dir() + "/../libasr/runtime "
732+
rt_path_02 = "-L" + get_rtlib_dir() + " -Wl,-rpath " \
733+
+ get_rtlib_dir() + " -llpython_runtime "
734+
python_lib = "-L" "$CONDA_PREFIX/lib/ -lpython3.10 -lm"
735+
r = os.system("gcc -g" + gcc_flags + python_path + numpy_path +
736+
" a.c -o lpython_jit_module.so " + rt_path_01 + rt_path_02 + python_lib)
737+
assert r == 0, "Failed to create the shared library"
738+
739+
def __call__(self, *args, **kwargs):
740+
import sys; sys.path.append('.')
741+
# import the symbol from the shared library
742+
function = getattr(__import__("lpython_jit_module"), self.fn_name)
743+
return function(*args, **kwargs)

0 commit comments

Comments
 (0)