Skip to content

Commit aa8142a

Browse files
committed
feat: automatically run jit
1 parent 3768761 commit aa8142a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/serialization/EnzymeJAX.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ Then in Python:
7474
from exported.my_function import run_my_function
7575
import jax
7676
77-
result = jax.jit(run_my_function)(*inputs)
77+
result = run_my_function(*inputs)
7878
```
7979
"""
8080
function export_to_enzymejax(
@@ -220,6 +220,7 @@ function _generate_python_script(
220220
inputs = [$(join(load_inputs, ", "))]
221221
return tuple(inputs)
222222
223+
@jax.jit
223224
def run_$(function_name)($(arg_list)):
224225
\"\"\"
225226
Call the exported Julia function via EnzymeJAX.
@@ -248,7 +249,7 @@ function _generate_python_script(
248249
249250
# Run the function (with JIT compilation)
250251
print(\"Running $(function_name) with JIT compilation...\")
251-
result = jax.jit(run_$(function_name))(*inputs)
252+
result = run_$(function_name)(*inputs)
252253
print(\"Result:\", result)
253254
"""
254255

0 commit comments

Comments
 (0)