-
Notifications
You must be signed in to change notification settings - Fork 38
Add export_to_enzymejax for automated JAX/EnzymeAD integration #1934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit
JuliaFormatter
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Line 7 in 3bd8d73
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Line 13 in 3bd8d73
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Line 17 in 3bd8d73
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Lines 20 to 22 in 3bd8d73
| simple_add, x, y; | |
| output_dir=tmpdir, | |
| function_name="simple_add" |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Line 24 in 3bd8d73
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Line 30 in 3bd8d73
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Line 35 in 3bd8d73
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Line 43 in 3bd8d73
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax.jl
Line 48 in 3bd8d73
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax_comprehensive.jl
Lines 19 to 21 in 3bd8d73
| matrix_multiply, x, y; | |
| output_dir=tmpdir, | |
| function_name="matrix_multiply" |
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax_comprehensive.jl
Lines 60 to 62 in 3bd8d73
| add_3d, x, y; | |
| output_dir=tmpdir, | |
| function_name="add_3d" |
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/export_enzymeax_comprehensive.jl
Lines 90 to 92 in 3bd8d73
| simple_fn, x; | |
| output_dir=tmpdir, | |
| function_name="test_fn" |
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
[JuliaFormatter] reported by reviewdog 🐶
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test/export_enzymeax.jl
Outdated
| @testset "Export to EnzymeJAX" begin | ||
| # Create a temporary directory for the export | ||
| tmpdir = mktempdir() | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
test/export_enzymeax.jl
Outdated
| function simple_add(x, y) | ||
| return x .+ y | ||
| end | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
test/export_enzymeax.jl
Outdated
| # Create some example inputs | ||
| x = Reactant.to_rarray(Float32[1, 2, 3]) | ||
| y = Reactant.to_rarray(Float32[4, 5, 6]) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
test/export_enzymeax.jl
Outdated
| simple_add, x, y; | ||
| output_dir=tmpdir, | ||
| function_name="simple_add" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| simple_add, x, y; | |
| output_dir=tmpdir, | |
| function_name="simple_add" | |
| simple_add, x, y; output_dir=tmpdir, function_name="simple_add" |
test/export_enzymeax.jl
Outdated
| output_dir=tmpdir, | ||
| function_name="simple_add" | ||
| ) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| @test occursin("(4, 1, 28, 28)", python_content) | ||
|
|
||
| println("✓ 3D array export test passed!") | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
|
|
||
| @testset "Export to EnzymeJAX - File Content Verification" begin | ||
| tmpdir = mktempdir() | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| function simple_fn(x) | ||
| return x .* 2.0f0 | ||
| end | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| end | ||
|
|
||
| x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0, 4.0]) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| simple_fn, x; | ||
| output_dir=tmpdir, | ||
| function_name="test_fn" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| simple_fn, x; | |
| output_dir=tmpdir, | |
| function_name="test_fn" | |
| simple_fn, x; output_dir=tmpdir, function_name="test_fn" |
src/serialization/EnzymeJAX.jl
Outdated
|
|
||
| for (i, arg) in enumerate(args) | ||
| # Convert to array if needed | ||
| arr = _to_array(arg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactored to use the existing seen_args infrastructure. Now iterating directly over mlir_fn_res.seen_args (OrderedIdDict mapping concrete args → traced args) and filtering by linear_args instead of creating a manual inverse map. Commit: 89b7690
src/serialization/Serialization.jl
Outdated
| end | ||
|
|
||
| """ | ||
| export_to_enzymeax( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot rename to export_to_enzymejax
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to export_to_enzymejax across all files. Commit: a53e20e
|
On running the export pipeline (identical to how """
Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.
This script was generated by Reactant.Serialization.export_to_enzymeax().
"""
from enzyme_ad.jax import hlo_call
import jax
import jax.numpy as jnp
import numpy as np
import os
# Get the directory of this script
_script_dir = os.path.dirname(os.path.abspath(__file__))
# Load the MLIR/StableHLO code
with open(os.path.join(_script_dir, "my_function_0.mlir"), "r") as f:
_hlo_code = f.read()
def load_inputs():
"""Load the example inputs that were exported from Julia."""
inputs = []
inputs.append(np.load(os.path.join(_script_dir, "my_function_0_input_1.npy")))
inputs.append(np.load(os.path.join(_script_dir, "my_function_0_input_2.npy")))
inputs.append(np.load(os.path.join(_script_dir, "my_function_0_input_3.npy")))
return tuple(inputs)
def run_my_function(arg1, arg2, arg3):
"""
Call the exported Julia function via EnzymeJAX.
Args:
arg1: Array of shape (3,) and dtype float32
arg2: Array of shape (3,) and dtype float32
arg3: Array of shape (3,) and dtype float32
Returns:
The result of calling the exported function.
Note:
All inputs must be in row-major (Python/NumPy) order. If you're passing
arrays from Julia, make sure to transpose them first using:
`permutedims(arr, reverse(1:ndims(arr)))`
"""
return hlo_call(
arg1, arg2, arg3,
source=_hlo_code,
)
if __name__ == "__main__":
# Load the example inputs
inputs = load_inputs()
# Run the function (with JIT compilation)
print("Running my_function with JIT compilation...")
result = jax.jit(run_my_function)(*inputs)
print("Result:", result) |
|
Just as a heads up, I was blocked by some firewall rules while working on your feedback. Expand below for details. Warning Firewall rules blocked me from connecting to one or more addresses (expand for details)I tried to connect to the following addresses, but was blocked by firewall rules:
If you need me to access, download, or install something from one of these locations, you can either:
|
|
@avik-pal specifically for the GB debugging sake, could we also figure out how to pass the sharding info as well? |
should work now |
Co-authored-by: avik-pal <[email protected]>
Co-authored-by: avik-pal <[email protected]>
…tation Co-authored-by: avik-pal <[email protected]>
Co-authored-by: avik-pal <[email protected]>
Co-authored-by: wsmoses <[email protected]>
73c2efa to
68eaf8a
Compare
68eaf8a to
c86486d
Compare
e39cd36 to
ac13dff
Compare
Refactored
export_to_enzymejaxto use Reactant's existingmake_tracerinfrastructure as suggested by @wsmoses. Instead of manually creating an inverse map from traced args to concrete args, we now directly iterate overmlir_fn_res.seen_argswhich is an OrderedIdDict that already contains this mapping.Changes
invmapcreation (lines 88-91)mlir_fn_res.seen_argswhich maps concrete args to traced argsmlir_fn_res.linear_argsexport_to_enzymeaxtoexport_to_enzymejaxfor naming consistencyThis aligns with the design of Reactant's compilation infrastructure where
seen_argsis already populated during the tracing process.Original prompt
✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.