|
1 |
| -from inspect import getfullargspec, getcallargs, isclass |
| 1 | +from inspect import getfullargspec, getcallargs, isclass, getsource |
2 | 2 | import os
|
3 | 3 | import ctypes
|
4 | 4 | import platform
|
5 | 5 | from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass
|
6 | 6 | from goto import with_goto
|
| 7 | +from numpy import get_include |
| 8 | +from distutils.sysconfig import get_python_inc |
7 | 9 |
|
8 | 10 | # TODO: this does not seem to restrict other imports
|
9 | 11 | __slots__ = ["i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "c32", "c64", "CPtr",
|
@@ -522,3 +524,220 @@ def ccallable(f):
|
522 | 524 |
|
523 | 525 | def ccallback(f):
|
524 | 526 | return f
|
| 527 | + |
| 528 | +class jit: |
| 529 | + def __init__(self, function): |
| 530 | + def get_rtlib_dir(): |
| 531 | + current_dir = os.path.dirname(os.path.abspath(__file__)) |
| 532 | + return os.path.join(current_dir, "..") |
| 533 | + |
| 534 | + def get_type_info(arg): |
| 535 | + # return_type -> (`type_format`, `variable type`, `array struct name`) |
| 536 | + # See: https://docs.python.org/3/c-api/arg.html for more info on type_format |
| 537 | + if arg == f64: |
| 538 | + return ('d', "double", 'r64') |
| 539 | + elif arg == f32: |
| 540 | + return ('f', "float", 'r32') |
| 541 | + elif arg == i64: |
| 542 | + return ('l', "long int", 'i64') |
| 543 | + elif arg == i32: |
| 544 | + return ('i', "int", 'i32') |
| 545 | + elif arg == bool: |
| 546 | + return ('p', "bool", '') |
| 547 | + elif isinstance(arg, Array): |
| 548 | + t = get_type_info(arg._type) |
| 549 | + if t[2] == '': |
| 550 | + raise NotImplementedError("Type %r not implemented" % arg) |
| 551 | + return ('O', ["PyArrayObject *", "struct "+t[2]+" *", t[1]+" *"], '') |
| 552 | + else: |
| 553 | + raise NotImplementedError("Type %r not implemented" % arg) |
| 554 | + |
| 555 | + def get_data_type(t): |
| 556 | + if isinstance(t, list): |
| 557 | + return t[0] |
| 558 | + else: |
| 559 | + return t + " " |
| 560 | + |
| 561 | + self.fn_name = function.__name__ |
| 562 | + # Get the source code of the function |
| 563 | + source_code = getsource(function) |
| 564 | + source_code = source_code[source_code.find('\n'):] |
| 565 | + |
| 566 | + # TODO: Create a filename based on the function name |
| 567 | + # filename = function.__name__ + ".py" |
| 568 | + |
| 569 | + # Open the file for writing |
| 570 | + with open("a.py", "w") as file: |
| 571 | + # Write the Python source code to the file |
| 572 | + file.write("@ccallable") |
| 573 | + file.write(source_code) |
| 574 | + # ---------------------------------------------------------------------- |
| 575 | + types = function.__annotations__ |
| 576 | + self.arg_type_formats = "" |
| 577 | + self.return_type = "" |
| 578 | + self.return_type_format = "" |
| 579 | + self.arg_types = {} |
| 580 | + counter = 1 |
| 581 | + for t in types.keys(): |
| 582 | + if t == "return": |
| 583 | + type = get_type_info(types[t]) |
| 584 | + self.return_type_format = type[0] |
| 585 | + self.return_type = type[1] |
| 586 | + else: |
| 587 | + type = get_type_info(types[t]) |
| 588 | + self.arg_type_formats += type[0] |
| 589 | + self.arg_types[counter] = type[1] |
| 590 | + counter += 1 |
| 591 | + # ---------------------------------------------------------------------- |
| 592 | + # `arg_0`: used as the return variables |
| 593 | + # arguments are declared as `arg_1`, `arg_2`, ... |
| 594 | + variables_decl = "" |
| 595 | + if self.return_type != "": |
| 596 | + variables_decl = "// Declare return variables and arguments\n" |
| 597 | + variables_decl += " " + get_data_type(self.return_type) + "arg_" \ |
| 598 | + + str(0) + ";\n" |
| 599 | + # ---------------------------------------------------------------------- |
| 600 | + # `PyArray_AsCArray` is used to convert NumPy Arrays to C Arrays |
| 601 | + # `fill_array_details` contains arrays operations to be |
| 602 | + # performed on the arguments |
| 603 | + # `parse_args` are used to capture the args from CPython |
| 604 | + # `pass_args` are the args that are passed to the shared library function |
| 605 | + fill_array_details = "" |
| 606 | + parse_args = "" |
| 607 | + pass_args = "" |
| 608 | + numpy_init = "" |
| 609 | + for i, t in self.arg_types.items(): |
| 610 | + if i > 1: |
| 611 | + parse_args += ", " |
| 612 | + pass_args += ", " |
| 613 | + if isinstance(t, list): |
| 614 | + if numpy_init == "": |
| 615 | + numpy_init = "// Initialize NumPy\n import_array();\n\n " |
| 616 | + fill_array_details += f"""\n |
| 617 | + // fill array details for args[{i-1}] |
| 618 | + if (PyArray_NDIM(arg_{i}) != 1) {{ |
| 619 | + PyErr_SetString(PyExc_TypeError, |
| 620 | + "Only 1 dimension is implemented for now."); |
| 621 | + return NULL; |
| 622 | + }} |
| 623 | +
|
| 624 | + {t[1]}s_array_{i} = malloc(sizeof(struct r64)); |
| 625 | + {{ |
| 626 | + {t[2]}array; |
| 627 | + // Create C arrays from numpy objects: |
| 628 | + PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE(arg_{i})); |
| 629 | + npy_intp dims[1]; |
| 630 | + if (PyArray_AsCArray((PyObject **)&arg_{i}, (void *)&array, dims, 1, descr) < 0) {{ |
| 631 | + PyErr_SetString(PyExc_TypeError, "error converting to c array"); |
| 632 | + return NULL; |
| 633 | + }} |
| 634 | +
|
| 635 | + s_array_{i}->data = array; |
| 636 | + s_array_{i}->n_dims = 1; |
| 637 | + s_array_{i}->dims[0].lower_bound = 0; |
| 638 | + s_array_{i}->dims[0].length = dims[0]; |
| 639 | + s_array_{i}->is_allocated = false; |
| 640 | + }}""" |
| 641 | + pass_args += "s_array_" + str(i) |
| 642 | + else: |
| 643 | + pass_args += "arg_" + str(i) |
| 644 | + variables_decl += " " + get_data_type(t) + "arg_" + str(i) + ";\n" |
| 645 | + parse_args += "&arg_" + str(i) |
| 646 | + |
| 647 | + if parse_args != "": |
| 648 | + parse_args = f"""\n // Parse the arguments from Python |
| 649 | + if (!PyArg_ParseTuple(args, "{self.arg_type_formats}", {parse_args})) {{ |
| 650 | + return NULL; |
| 651 | + }}""" |
| 652 | + |
| 653 | + # ---------------------------------------------------------------------- |
| 654 | + # Handle the return variable if any; otherwise, return None |
| 655 | + fill_return_details = "" |
| 656 | + if self.return_type != "": |
| 657 | + fill_return_details = f"""\n\n // Call the C function |
| 658 | + arg_0 = {self.fn_name}({pass_args}); |
| 659 | +
|
| 660 | + // Build and return the result as a Python object |
| 661 | + return Py_BuildValue("{self.return_type_format}", arg_0);""" |
| 662 | + else: |
| 663 | + fill_return_details = f"""{self.fn_name}({pass_args}); |
| 664 | + Py_RETURN_NONE;""" |
| 665 | + |
| 666 | + # ---------------------------------------------------------------------- |
| 667 | + # Python wrapper for the Shared library |
| 668 | + template = f"""// Python headers |
| 669 | +#include <Python.h> |
| 670 | +
|
| 671 | +// NumPy C/API headers |
| 672 | +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION // remove warnings |
| 673 | +#include <numpy/ndarrayobject.h> |
| 674 | +
|
| 675 | +// LPython generated C code |
| 676 | +#include "a.h" |
| 677 | +
|
| 678 | +// Define the Python module and method mappings |
| 679 | +static PyObject* define_module(PyObject* self, PyObject* args) {{ |
| 680 | + {numpy_init}{variables_decl}{parse_args}\ |
| 681 | +{fill_array_details}{fill_return_details} |
| 682 | +}} |
| 683 | +
|
| 684 | +// Define the module's method table |
| 685 | +static PyMethodDef module_methods[] = {{ |
| 686 | + {{"{self.fn_name}", define_module, METH_VARARGS, |
| 687 | + "Handle arguments & return variable and call the function"}}, |
| 688 | + {{NULL, NULL, 0, NULL}} |
| 689 | +}}; |
| 690 | +
|
| 691 | +// Define the module initialization function |
| 692 | +static struct PyModuleDef module_def = {{ |
| 693 | + PyModuleDef_HEAD_INIT, |
| 694 | + "lpython_jit_module", |
| 695 | + "Shared library to use LPython generated functions", |
| 696 | + -1, |
| 697 | + module_methods |
| 698 | +}}; |
| 699 | +
|
| 700 | +PyMODINIT_FUNC PyInit_lpython_jit_module(void) {{ |
| 701 | + PyObject* module; |
| 702 | +
|
| 703 | + // Create the module object |
| 704 | + module = PyModule_Create(&module_def); |
| 705 | + if (!module) {{ |
| 706 | + return NULL; |
| 707 | + }} |
| 708 | +
|
| 709 | + return module; |
| 710 | +}} |
| 711 | +""" |
| 712 | + # ---------------------------------------------------------------------- |
| 713 | + # Write the C source code to the file |
| 714 | + with open("a.c", "w") as file: |
| 715 | + file.write(template) |
| 716 | + |
| 717 | + # ---------------------------------------------------------------------- |
| 718 | + # Generate the Shared library |
| 719 | + # TODO: Use LLVM instead of C backend |
| 720 | + r = os.system("lpython --show-c --disable-main a.py > a.h") |
| 721 | + assert r == 0, "Failed to create C file" |
| 722 | + gcc_flags = "" |
| 723 | + if platform.system() == "Linux": |
| 724 | + gcc_flags = " -shared -fPIC " |
| 725 | + elif platform.system() == "Darwin": |
| 726 | + gcc_flags = " -bundle -flat_namespace -undefined suppress " |
| 727 | + else: |
| 728 | + raise NotImplementedError("Platform not implemented") |
| 729 | + python_path = "-I" + get_python_inc() + " " |
| 730 | + numpy_path = "-I" + get_include() |
| 731 | + rt_path_01 = "-I" + get_rtlib_dir() + "/../libasr/runtime " |
| 732 | + rt_path_02 = "-L" + get_rtlib_dir() + " -Wl,-rpath " \ |
| 733 | + + get_rtlib_dir() + " -llpython_runtime " |
| 734 | + python_lib = "-L" "$CONDA_PREFIX/lib/ -lpython3.10 -lm" |
| 735 | + r = os.system("gcc -g" + gcc_flags + python_path + numpy_path + |
| 736 | + " a.c -o lpython_jit_module.so " + rt_path_01 + rt_path_02 + python_lib) |
| 737 | + assert r == 0, "Failed to create the shared library" |
| 738 | + |
| 739 | + def __call__(self, *args, **kwargs): |
| 740 | + import sys; sys.path.append('.') |
| 741 | + # import the symbol from the shared library |
| 742 | + function = getattr(__import__("lpython_jit_module"), self.fn_name) |
| 743 | + return function(*args, **kwargs) |
0 commit comments