File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -74,7 +74,7 @@ Then in Python:
7474from exported.my_function import run_my_function
7575import jax
7676
77- result = jax.jit( run_my_function) (*inputs)
77+ result = run_my_function(*inputs)
7878```
7979"""
8080function export_to_enzymejax (
@@ -220,6 +220,7 @@ function _generate_python_script(
220220 inputs = [$(join (load_inputs, " , " )) ]
221221 return tuple(inputs)
222222
223+ @jax.jit
223224 def run_$(function_name) ($(arg_list) ):
224225 \"\"\"
225226 Call the exported Julia function via EnzymeJAX.
@@ -248,7 +249,7 @@ function _generate_python_script(
248249
249250 # Run the function (with JIT compilation)
250251 print(\" Running $(function_name) with JIT compilation...\" )
251- result = jax.jit( run_$(function_name) )(*inputs)
252+ result = run_$(function_name) (*inputs)
252253 print(\" Result:\" , result)
253254 """
254255
You can’t perform that action at this time.
0 commit comments