Skip to content

Commit ae75160

Browse files
Copilotavik-pal
andcommitted
Add export_to_enzymeax function for JAX integration
Co-authored-by: avik-pal <[email protected]>
1 parent a6757ca commit ae75160

File tree

3 files changed

+440
-0
lines changed

3 files changed

+440
-0
lines changed

src/serialization/EnzymeJAX.jl

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
module EnzymeJAX
2+
3+
using ..Reactant: Reactant, Compiler, MLIR
4+
5+
const NUMPY_SIMPLE_TYPES = Dict(
6+
Bool => "np.bool_",
7+
Int8 => "np.int8",
8+
Int16 => "np.int16",
9+
Int32 => "np.int32",
10+
Int64 => "np.int64",
11+
UInt8 => "np.uint8",
12+
UInt16 => "np.uint16",
13+
UInt32 => "np.uint32",
14+
UInt64 => "np.uint64",
15+
Float16 => "np.float16",
16+
Float32 => "np.float32",
17+
Float64 => "np.float64",
18+
ComplexF16 => "np.complex64", # Note: NumPy doesn't have float16 complex
19+
ComplexF32 => "np.complex64",
20+
ComplexF64 => "np.complex128",
21+
)
22+
23+
"""
24+
export_to_enzymeax(
25+
f,
26+
args...;
27+
output_dir::String=".",
28+
function_name::String="exported_function",
29+
)
30+
31+
Export a Julia function to EnzymeJAX format for use in Python/JAX.
32+
33+
This function:
34+
1. Compiles the function to StableHLO via `Reactant.@code_hlo`
35+
2. Saves the MLIR/StableHLO code to a `.mlir` file
36+
3. Saves input arrays to `.npy` files (transposed to account for row-major vs column-major)
37+
4. Generates a Python script with the function wrapped for EnzymeJAX's `hlo_call`
38+
39+
## Arguments
40+
41+
- `f`: The Julia function to export
42+
- `args...`: The arguments to the function (used to infer types and shapes)
43+
44+
## Keyword Arguments
45+
46+
- `output_dir::String="."`: Directory where output files will be saved
47+
- `function_name::String="exported_function"`: Base name for generated files
48+
49+
## Returns
50+
51+
A tuple `(mlir_path, python_path, input_paths)` containing paths to:
52+
- The generated `.mlir` file
53+
- The generated `.py` file
54+
- A vector of paths to input `.npy` files
55+
56+
## Example
57+
58+
```julia
59+
using Reactant
60+
61+
# Define a simple function
62+
function my_function(x, y)
63+
return x .+ y
64+
end
65+
66+
# Create some example inputs
67+
x = Reactant.to_rarray(Float32[1, 2, 3])
68+
y = Reactant.to_rarray(Float32[4, 5, 6])
69+
70+
# Export to EnzymeJAX
71+
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax(
72+
my_function, x, y;
73+
output_dir="/tmp/exported",
74+
function_name="my_function"
75+
)
76+
```
77+
78+
Then in Python:
79+
```python
80+
# Run the generated Python script
81+
from exported.my_function import run_my_function
82+
import jax
83+
84+
result = jax.jit(run_my_function)(*inputs)
85+
```
86+
"""
87+
function export_to_enzymeax(
88+
f,
89+
args...;
90+
output_dir::String=".",
91+
function_name::String="exported_function",
92+
)
93+
# Create output directory if it doesn't exist
94+
mkpath(output_dir)
95+
96+
# Generate the StableHLO/MLIR code using compile_mlir directly
97+
mod, mlir_fn_res = Compiler.compile_mlir(
98+
f, args;
99+
shardy_passes=:none
100+
)
101+
hlo_code = string(mod)
102+
103+
# Save MLIR code
104+
mlir_path = joinpath(output_dir, "$(function_name).mlir")
105+
write(mlir_path, hlo_code)
106+
107+
# Process and save inputs
108+
input_paths = String[]
109+
input_info = []
110+
111+
for (i, arg) in enumerate(args)
112+
# Convert to array if needed
113+
arr = _to_array(arg)
114+
115+
# Save the input (transposed for row-major Python/NumPy)
116+
input_path = joinpath(output_dir, "$(function_name)_input_$(i).npy")
117+
_save_transposed_array(input_path, arr)
118+
push!(input_paths, input_path)
119+
120+
# Store shape and dtype info (in Julia's column-major ordering)
121+
push!(input_info, (shape=size(arr), dtype=eltype(arr)))
122+
end
123+
124+
# Generate Python script
125+
python_path = joinpath(output_dir, "$(function_name).py")
126+
_generate_python_script(python_path, function_name, mlir_path, input_paths, input_info)
127+
128+
return (mlir_path, python_path, input_paths)
129+
end
130+
131+
"""
132+
Convert Reactant types to regular Julia arrays for saving.
133+
"""
134+
function _to_array(x::Reactant.ConcreteRArray)
135+
return Array(x)
136+
end
137+
138+
function _to_array(x::Reactant.ConcreteRNumber)
139+
return [x.data]
140+
end
141+
142+
function _to_array(x::AbstractArray)
143+
return Array(x)
144+
end
145+
146+
function _to_array(x::Number)
147+
return [x]
148+
end
149+
150+
function _to_array(x::Tuple)
151+
error("Tuple arguments are not yet supported. Please flatten your arguments.")
152+
end
153+
154+
function _to_array(x::NamedTuple)
155+
error("NamedTuple arguments are not yet supported. Please flatten your arguments.")
156+
end
157+
158+
"""
159+
Save an array to a .npy file, transposing to account for row-major vs column-major ordering.
160+
"""
161+
function _save_transposed_array(path::String, arr::AbstractArray)
162+
# For multi-dimensional arrays, we need to reverse the dimensions for Python/NumPy
163+
# Julia: column-major (fastest changing index first)
164+
# Python: row-major (fastest changing index last)
165+
transposed = permutedims(arr, reverse(1:ndims(arr)))
166+
167+
# Use a simple .npy writer
168+
# NPY format v1.0: magic (6 bytes) + version (2 bytes) + header_len (2 bytes) + header + data
169+
open(path, "w") do io
170+
# Magic number for .npy format
171+
write(io, UInt8[0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59])
172+
# Version 1.0
173+
write(io, UInt8[0x01, 0x00])
174+
175+
# Prepare header
176+
dtype_str = _numpy_dtype_string(eltype(arr))
177+
shape_str = join(size(transposed), ", ")
178+
header = "{'descr': '$(dtype_str)', 'fortran_order': False, 'shape': ($(shape_str),)}"
179+
180+
# Pad header to be aligned on 64 bytes
181+
header_len = length(header) + 1 # +1 for newline
182+
total_len = 10 + header_len # 10 = magic(6) + version(2) + header_len(2)
183+
padding = (64 - (total_len % 64)) % 64
184+
header = header * " "^padding * "\n"
185+
header_len = length(header)
186+
187+
# Write header length (little-endian UInt16)
188+
write(io, UInt16(header_len))
189+
# Write header
190+
write(io, header)
191+
# Write data
192+
write(io, vec(transposed))
193+
end
194+
end
195+
196+
"""
197+
Get NumPy dtype string for a Julia type.
198+
"""
199+
function _numpy_dtype_string(::Type{Bool})
200+
return "|b1"
201+
end
202+
203+
function _numpy_dtype_string(::Type{Int8})
204+
return "|i1"
205+
end
206+
207+
function _numpy_dtype_string(::Type{UInt8})
208+
return "|u1"
209+
end
210+
211+
function _numpy_dtype_string(::Type{Int16})
212+
return "<i2"
213+
end
214+
215+
function _numpy_dtype_string(::Type{UInt16})
216+
return "<u2"
217+
end
218+
219+
function _numpy_dtype_string(::Type{Int32})
220+
return "<i4"
221+
end
222+
223+
function _numpy_dtype_string(::Type{UInt32})
224+
return "<u4"
225+
end
226+
227+
function _numpy_dtype_string(::Type{Int64})
228+
return "<i8"
229+
end
230+
231+
function _numpy_dtype_string(::Type{UInt64})
232+
return "<u8"
233+
end
234+
235+
function _numpy_dtype_string(::Type{Float16})
236+
return "<f2"
237+
end
238+
239+
function _numpy_dtype_string(::Type{Float32})
240+
return "<f4"
241+
end
242+
243+
function _numpy_dtype_string(::Type{Float64})
244+
return "<f8"
245+
end
246+
247+
function _numpy_dtype_string(::Type{ComplexF32})
248+
return "<c8"
249+
end
250+
251+
function _numpy_dtype_string(::Type{ComplexF64})
252+
return "<c16"
253+
end
254+
255+
"""
256+
Generate a Python script that uses EnzymeJAX to call the exported function.
257+
"""
258+
function _generate_python_script(
259+
python_path::String,
260+
function_name::String,
261+
mlir_path::String,
262+
input_paths::Vector{String},
263+
input_info::Vector,
264+
)
265+
# Get relative paths for the Python script
266+
output_dir = dirname(python_path)
267+
mlir_rel = relpath(mlir_path, output_dir)
268+
input_rels = [relpath(p, output_dir) for p in input_paths]
269+
270+
# Start building the Python script
271+
script = """
272+
\"\"\"
273+
Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.
274+
275+
This script was generated by Reactant.Serialization.export_to_enzymeax().
276+
\"\"\"
277+
278+
from enzyme_ad.jax import hlo_call
279+
import jax
280+
import jax.numpy as jnp
281+
import numpy as np
282+
import os
283+
284+
# Get the directory of this script
285+
_script_dir = os.path.dirname(os.path.abspath(__file__))
286+
287+
# Load the MLIR/StableHLO code
288+
with open(os.path.join(_script_dir, "$(mlir_rel)"), "r") as f:
289+
_hlo_code = f.read()
290+
291+
"""
292+
293+
# Add function to load inputs
294+
script *= """
295+
def load_inputs():
296+
\"\"\"Load the example inputs that were exported from Julia.\"\"\"
297+
inputs = []
298+
"""
299+
300+
for (i, input_rel) in enumerate(input_rels)
301+
script *= """
302+
inputs.append(np.load(os.path.join(_script_dir, "$(input_rel)")))
303+
"""
304+
end
305+
306+
script *= """
307+
return tuple(inputs)
308+
309+
"""
310+
311+
# Add the main function that calls the HLO code
312+
arg_names = ["arg$i" for i in 1:length(input_paths)]
313+
arg_list = join(arg_names, ", ")
314+
315+
script *= """
316+
def run_$(function_name)($(arg_list)):
317+
\"\"\"
318+
Call the exported Julia function via EnzymeJAX.
319+
320+
Args:
321+
"""
322+
323+
for (i, info) in enumerate(input_info)
324+
# Note: shapes are already transposed for Python
325+
python_shape = reverse(info.shape)
326+
script *= """
327+
$(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])
328+
"""
329+
end
330+
331+
script *= """
332+
333+
Returns:
334+
The result of calling the exported function.
335+
336+
Note:
337+
All inputs must be in row-major (Python/NumPy) order. If you're passing
338+
arrays from Julia, make sure to transpose them first using:
339+
`permutedims(arr, reverse(1:ndims(arr)))`
340+
\"\"\"
341+
return hlo_call(
342+
$(arg_list),
343+
source=_hlo_code,
344+
)
345+
346+
"""
347+
348+
# Add a main block for testing
349+
script *= """
350+
if __name__ == "__main__":
351+
# Load the example inputs
352+
inputs = load_inputs()
353+
354+
# Run the function (with JIT compilation)
355+
print("Running $(function_name) with JIT compilation...")
356+
result = jax.jit(run_$(function_name))(*inputs)
357+
print("Result:", result)
358+
print("Result shape:", result.shape if hasattr(result, 'shape') else 'scalar')
359+
print("Result dtype:", result.dtype if hasattr(result, 'dtype') else type(result))
360+
"""
361+
362+
write(python_path, script)
363+
end
364+
365+
end # module

0 commit comments

Comments
 (0)