|
1 | 1 | module EnzymeJAX |
2 | 2 |
|
3 | | -using ..Reactant: Reactant, Compiler, MLIR |
4 | | - |
5 | | -const NUMPY_SIMPLE_TYPES = Dict( |
6 | | - Bool => "np.bool_", |
7 | | - Int8 => "np.int8", |
8 | | - Int16 => "np.int16", |
9 | | - Int32 => "np.int32", |
10 | | - Int64 => "np.int64", |
11 | | - UInt8 => "np.uint8", |
12 | | - UInt16 => "np.uint16", |
13 | | - UInt32 => "np.uint32", |
14 | | - UInt64 => "np.uint64", |
15 | | - Float16 => "np.float16", |
16 | | - Float32 => "np.float32", |
17 | | - Float64 => "np.float64", |
18 | | - ComplexF32 => "np.complex64", |
19 | | - ComplexF64 => "np.complex128", |
20 | | -) |
| 3 | +using ..Reactant: Reactant, Compiler, MLIR, Serialization |
21 | 4 |
|
22 | 5 | """ |
23 | 6 | export_to_enzymeax( |
@@ -264,104 +247,88 @@ function _generate_python_script( |
264 | 247 | mlir_rel = relpath(mlir_path, output_dir) |
265 | 248 | input_rels = [relpath(p, output_dir) for p in input_paths] |
266 | 249 |
|
267 | | - # Build the Python script without leading indentation |
268 | | - lines = String[] |
269 | | - |
270 | | - # Header |
271 | | - push!(lines, "\"\"\"") |
272 | | - push!( |
273 | | - lines, |
274 | | - "Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.", |
275 | | - ) |
276 | | - push!(lines, "") |
277 | | - push!( |
278 | | - lines, "This script was generated by Reactant.Serialization.export_to_enzymeax()." |
| 250 | + # Generate input loading code |
| 251 | + input_loads = join( |
| 252 | + [ |
| 253 | + " inputs.append(np.load(os.path.join(_script_dir, \"$rel\")))" for |
| 254 | + rel in input_rels |
| 255 | + ], |
| 256 | + "\n", |
279 | 257 | ) |
280 | | - push!(lines, "\"\"\"") |
281 | | - push!(lines, "") |
282 | | - push!(lines, "from enzyme_ad.jax import hlo_call") |
283 | | - push!(lines, "import jax") |
284 | | - push!(lines, "import jax.numpy as jnp") |
285 | | - push!(lines, "import numpy as np") |
286 | | - push!(lines, "import os") |
287 | | - push!(lines, "") |
288 | | - push!(lines, "# Get the directory of this script") |
289 | | - push!(lines, "_script_dir = os.path.dirname(os.path.abspath(__file__))") |
290 | | - push!(lines, "") |
291 | | - push!(lines, "# Load the MLIR/StableHLO code") |
292 | | - push!(lines, "with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f:") |
293 | | - push!(lines, " _hlo_code = f.read()") |
294 | | - push!(lines, "") |
295 | | - |
296 | | - # Function to load inputs |
297 | | - push!(lines, "def load_inputs():") |
298 | | - push!(lines, " \"\"\"Load the example inputs that were exported from Julia.\"\"\"") |
299 | | - push!(lines, " inputs = []") |
300 | | - for input_rel in input_rels |
301 | | - push!( |
302 | | - lines, " inputs.append(np.load(os.path.join(_script_dir, \"$(input_rel)\")))" |
303 | | - ) |
304 | | - end |
305 | | - push!(lines, " return tuple(inputs)") |
306 | | - push!(lines, "") |
307 | 258 |
|
308 | | - # Main function |
| 259 | + # Generate argument list and documentation |
309 | 260 | arg_names = ["arg$i" for i in 1:length(input_paths)] |
310 | 261 | arg_list = join(arg_names, ", ") |
311 | 262 |
|
312 | | - push!(lines, "def run_$(function_name)($(arg_list)):") |
313 | | - push!(lines, " \"\"\"") |
314 | | - push!(lines, " Call the exported Julia function via EnzymeJAX.") |
315 | | - push!(lines, " ") |
316 | | - push!(lines, " Args:") |
317 | | - |
318 | | - for (i, info) in enumerate(input_info) |
319 | | - # Note: shapes are already transposed for Python |
320 | | - python_shape = reverse(info.shape) |
321 | | - push!( |
322 | | - lines, |
323 | | - " $(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])", |
| 263 | + # Generate docstring for arguments |
| 264 | + arg_docs = join( |
| 265 | + [ |
| 266 | + " $(arg_names[i]): Array of shape $(reverse(info.shape)) and dtype $(Serialization.NUMPY_SIMPLE_TYPES[info.dtype])" |
| 267 | + for (i, info) in enumerate(input_info) |
| 268 | + ], |
| 269 | + "\n", |
| 270 | + ) |
| 271 | + |
| 272 | + # Build the complete Python script |
| 273 | + script = """ |
| 274 | + \"\"\" |
| 275 | + Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX. |
| 276 | +
|
| 277 | + This script was generated by Reactant.Serialization.export_to_enzymeax(). |
| 278 | + \"\"\" |
| 279 | +
|
| 280 | + from enzyme_ad.jax import hlo_call |
| 281 | + import jax |
| 282 | + import jax.numpy as jnp |
| 283 | + import numpy as np |
| 284 | + import os |
| 285 | +
|
| 286 | + # Get the directory of this script |
| 287 | + _script_dir = os.path.dirname(os.path.abspath(__file__)) |
| 288 | +
|
| 289 | + # Load the MLIR/StableHLO code |
| 290 | + with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f: |
| 291 | + _hlo_code = f.read() |
| 292 | +
|
| 293 | + def load_inputs(): |
| 294 | + \"\"\"Load the example inputs that were exported from Julia.\"\"\" |
| 295 | + inputs = [] |
| 296 | + $input_loads |
| 297 | + return tuple(inputs) |
| 298 | +
|
| 299 | + def run_$(function_name)($(arg_list)): |
| 300 | + \"\"\" |
| 301 | + Call the exported Julia function via EnzymeJAX. |
| 302 | +
|
| 303 | + Args: |
| 304 | + $arg_docs |
| 305 | +
|
| 306 | + Returns: |
| 307 | + The result of calling the exported function. |
| 308 | +
|
| 309 | + Note: |
| 310 | + All inputs must be in row-major (Python/NumPy) order. If you're passing |
| 311 | + arrays from Julia, make sure to transpose them first using: |
| 312 | + \`permutedims(arr, reverse(1:ndims(arr)))\` |
| 313 | + \"\"\" |
| 314 | + return hlo_call( |
| 315 | + $(arg_list), |
| 316 | + source=_hlo_code, |
324 | 317 | ) |
325 | | - end |
326 | 318 |
|
327 | | - push!(lines, " ") |
328 | | - push!(lines, " Returns:") |
329 | | - push!(lines, " The result of calling the exported function.") |
330 | | - push!(lines, " ") |
331 | | - push!(lines, " Note:") |
332 | | - push!( |
333 | | - lines, |
334 | | - " All inputs must be in row-major (Python/NumPy) order. If you're passing", |
335 | | - ) |
336 | | - push!(lines, " arrays from Julia, make sure to transpose them first using:") |
337 | | - push!(lines, " `permutedims(arr, reverse(1:ndims(arr)))`") |
338 | | - push!(lines, " \"\"\"") |
339 | | - push!(lines, " return hlo_call(") |
340 | | - push!(lines, " $(arg_list),") |
341 | | - push!(lines, " source=_hlo_code,") |
342 | | - push!(lines, " )") |
343 | | - push!(lines, "") |
344 | | - |
345 | | - # Main block |
346 | | - push!(lines, "if __name__ == \"__main__\":") |
347 | | - push!(lines, " # Load the example inputs") |
348 | | - push!(lines, " inputs = load_inputs()") |
349 | | - push!(lines, " ") |
350 | | - push!(lines, " # Run the function (with JIT compilation)") |
351 | | - push!(lines, " print(\"Running $(function_name) with JIT compilation...\")") |
352 | | - push!(lines, " result = jax.jit(run_$(function_name))(*inputs)") |
353 | | - push!(lines, " print(\"Result:\", result)") |
354 | | - push!( |
355 | | - lines, |
356 | | - " print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar')", |
357 | | - ) |
358 | | - push!( |
359 | | - lines, |
360 | | - " print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result))", |
361 | | - ) |
| 319 | + if __name__ == \"__main__\": |
| 320 | + # Load the example inputs |
| 321 | + inputs = load_inputs() |
| 322 | +
|
| 323 | + # Run the function (with JIT compilation) |
| 324 | + print(\"Running $(function_name) with JIT compilation...\") |
| 325 | + result = jax.jit(run_$(function_name))(*inputs) |
| 326 | + print(\"Result:\", result) |
| 327 | + print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar') |
| 328 | + print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result)) |
| 329 | + """ |
362 | 330 |
|
363 | | - # Write the script |
364 | | - write(python_path, join(lines, "\n") * "\n") |
| 331 | + write(python_path, strip(script) * "\n") |
365 | 332 | return nothing |
366 | 333 | end |
367 | 334 |
|
|
0 commit comments