@@ -517,6 +517,37 @@ def ccallback(f):
517
517
518
518
class jit :
519
519
def __init__ (self , function ):
520
+ def get_rtlib_dir ():
521
+ current_dir = os .path .dirname (os .path .abspath (__file__ ))
522
+ return os .path .join (current_dir , ".." )
523
+
524
+ def get_type_info (arg ):
525
+ # return_type -> (`type_format`, `variable type`, `array struct name`)
526
+ # See: https://docs.python.org/3/c-api/arg.html for more info on type_format
527
+ if arg == f64 :
528
+ return ('d' , "double" , 'r64' )
529
+ elif arg == f32 :
530
+ return ('f' , "float" , 'r32' )
531
+ elif arg == i64 :
532
+ return ('l' , "long int" , 'i64' )
533
+ elif arg == i32 :
534
+ return ('i' , "int" , 'i32' )
535
+ elif arg == bool :
536
+ return ('p' , "bool" , '' )
537
+ elif isinstance (arg , Array ):
538
+ t = get_type_info (arg ._type )
539
+ if t [2 ] == '' :
540
+ raise NotImplementedError ("Type %r not implemented" % arg )
541
+ return ('O' , ["PyArrayObject *" , "struct " + t [2 ]+ " *" , t [1 ]+ " *" ], '' )
542
+ else :
543
+ raise NotImplementedError ("Type %r not implemented" % arg )
544
+
545
+ def get_data_type (t ):
546
+ if isinstance (t , list ):
547
+ return t [0 ]
548
+ else :
549
+ return t + " "
550
+
520
551
self .fn_name = function .__name__
521
552
# Get the source code of the function
522
553
source_code = getsource (function )
@@ -530,6 +561,148 @@ def __init__(self, function):
530
561
# Write the Python source code to the file
531
562
file .write ("@ccallable" )
532
563
file .write (source_code )
564
+ # ----------------------------------------------------------------------
565
+ types = function .__annotations__
566
+ self .arg_type_formats = ""
567
+ self .return_type = ""
568
+ self .return_type_format = ""
569
+ self .arg_types = {}
570
+ counter = 1
571
+ for t in types .keys ():
572
+ if t == "return" :
573
+ type = get_type_info (types [t ])
574
+ self .return_type_format = type [0 ]
575
+ self .return_type = type [1 ]
576
+ else :
577
+ type = get_type_info (types [t ])
578
+ self .arg_type_formats += type [0 ]
579
+ self .arg_types [counter ] = type [1 ]
580
+ counter += 1
581
+ # ----------------------------------------------------------------------
582
+ # `arg_0`: used as the return variables
583
+ # arguments are declared as `arg_1`, `arg_2`, ...
584
+ variables_decl = ""
585
+ if self .return_type != "" :
586
+ variables_decl = "// Declare return variables and arguments\n "
587
+ variables_decl += " " + get_data_type (self .return_type ) + "arg_" \
588
+ + str (0 ) + ";\n "
589
+ # ----------------------------------------------------------------------
590
+ # `PyArray_AsCArray` is used to convert NumPy Arrays to C Arrays
591
+ # `fill_array_details` contains arrays operations to be
592
+ # performed on the arguments
593
+ # `parse_args` are used to capture the args from CPython
594
+ # `pass_args` are the args that are passed to the shared library function
595
+ fill_array_details = ""
596
+ parse_args = ""
597
+ pass_args = ""
598
+ numpy_init = ""
599
+ for i , t in self .arg_types .items ():
600
+ if i > 1 :
601
+ parse_args += ", "
602
+ pass_args += ", "
603
+ if isinstance (t , list ):
604
+ if numpy_init == "" :
605
+ numpy_init = "// Initialize NumPy\n import_array();\n \n "
606
+ fill_array_details += f"""\n
607
+ // fill array details for args[{ i - 1 } ]
608
+ if (PyArray_NDIM(arg_{ i } ) != 1) {{
609
+ PyErr_SetString(PyExc_TypeError,
610
+ "Only 1 dimension is implemented for now.");
611
+ return NULL;
612
+ }}
613
+
614
+ { t [1 ]} s_array_{ i } = malloc(sizeof(struct r64));
615
+ {{
616
+ { t [2 ]} array;
617
+ // Create C arrays from numpy objects:
618
+ PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE(arg_{ i } ));
619
+ npy_intp dims[1];
620
+ if (PyArray_AsCArray((PyObject **)&arg_{ i } , (void *)&array, dims, 1, descr) < 0) {{
621
+ PyErr_SetString(PyExc_TypeError, "error converting to c array");
622
+ return NULL;
623
+ }}
624
+
625
+ s_array_{ i } ->data = array;
626
+ s_array_{ i } ->n_dims = 1;
627
+ s_array_{ i } ->dims[0].lower_bound = 0;
628
+ s_array_{ i } ->dims[0].length = dims[0];
629
+ s_array_{ i } ->is_allocated = false;
630
+ }}"""
631
+ pass_args += "s_array_" + str (i )
632
+ else :
633
+ pass_args += "arg_" + str (i )
634
+ variables_decl += " " + get_data_type (t ) + "arg_" + str (i ) + ";\n "
635
+ parse_args += "&arg_" + str (i )
636
+
637
+ if parse_args != "" :
638
+ parse_args = f"""\n // Parse the arguments from Python
639
+ if (!PyArg_ParseTuple(args, "{ self .arg_type_formats } ", { parse_args } )) {{
640
+ return NULL;
641
+ }}"""
642
+
643
+ # ----------------------------------------------------------------------
644
+ # Handle the return variable if any; otherwise, return None
645
+ fill_return_details = ""
646
+ if self .return_type != "" :
647
+ fill_return_details = f"""\n \n // Call the C function
648
+ arg_0 = { self .fn_name } ({ pass_args } );
649
+
650
+ // Build and return the result as a Python object
651
+ return Py_BuildValue("{ self .return_type_format } ", arg_0);"""
652
+ else :
653
+ fill_return_details = f"""{ self .fn_name } ({ pass_args } );
654
+ Py_RETURN_NONE;"""
655
+
656
+ # ----------------------------------------------------------------------
657
+ # Python wrapper for the Shared library
658
+ template = f"""// Python headers
659
+ #include <Python.h>
660
+
661
+ // NumPy C/API headers
662
+ #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION // remove warnings
663
+ #include <numpy/ndarrayobject.h>
664
+
665
+ // LPython generated C code
666
+ #include "a.h"
667
+
668
+ // Define the Python module and method mappings
669
+ static PyObject* define_module(PyObject* self, PyObject* args) {{
670
+ { numpy_init } { variables_decl } { parse_args } \
671
+ { fill_array_details } { fill_return_details }
672
+ }}
673
+
674
+ // Define the module's method table
675
+ static PyMethodDef module_methods[] = {{
676
+ {{"{ self .fn_name } ", define_module, METH_VARARGS,
677
+ "Handle arguments & return variable and call the function"}},
678
+ {{NULL, NULL, 0, NULL}}
679
+ }};
680
+
681
+ // Define the module initialization function
682
+ static struct PyModuleDef module_def = {{
683
+ PyModuleDef_HEAD_INIT,
684
+ "lpython_jit_module",
685
+ "Shared library to use LPython generated functions",
686
+ -1,
687
+ module_methods
688
+ }};
689
+
690
+ PyMODINIT_FUNC PyInit_lpython_jit_module(void) {{
691
+ PyObject* module;
692
+
693
+ // Create the module object
694
+ module = PyModule_Create(&module_def);
695
+ if (!module) {{
696
+ return NULL;
697
+ }}
698
+
699
+ return module;
700
+ }}
701
+ """
702
+ # ----------------------------------------------------------------------
703
+ # Write the C source code to the file
704
+ with open ("a.c" , "w" ) as file :
705
+ file .write (template )
533
706
534
707
# ----------------------------------------------------------------------
535
708
# TODO: Use LLVM instead of C backend
0 commit comments