@@ -15,7 +15,6 @@ const NUMPY_SIMPLE_TYPES = Dict(
1515 Float16 => " np.float16" ,
1616 Float32 => " np.float32" ,
1717 Float64 => " np.float64" ,
18- ComplexF16 => " np.complex64" , # Note: NumPy doesn't have float16 complex
1918 ComplexF32 => " np.complex64" ,
2019 ComplexF64 => " np.complex128" ,
2120)
@@ -177,10 +176,11 @@ function _save_transposed_array(path::String, arr::AbstractArray)
177176 shape_str = join (size (transposed), " , " )
178177 header = " {'descr': '$(dtype_str) ', 'fortran_order': False, 'shape': ($(shape_str) ,)}"
179178
180- # Pad header to be aligned on 64 bytes
179+ # Pad header to be aligned on 64 bytes (16-byte alignment for v1.0)
180+ # Total size needs to be divisible by 16
181181 header_len = length (header) + 1 # +1 for newline
182182 total_len = 10 + header_len # 10 = magic(6) + version(2) + header_len(2)
183- padding = (64 - (total_len % 64 )) % 64
183+ padding = (16 - (total_len % 16 )) % 16
184184 header = header * " " ^ padding * " \n "
185185 header_len = length (header)
186186
@@ -191,6 +191,7 @@ function _save_transposed_array(path::String, arr::AbstractArray)
191191 # Write data
192192 write (io, vec (transposed))
193193 end
194+ return nothing
194195end
195196
196197"""
@@ -267,99 +268,86 @@ function _generate_python_script(
267268 mlir_rel = relpath (mlir_path, output_dir)
268269 input_rels = [relpath (p, output_dir) for p in input_paths]
269270
270- # Start building the Python script
271- script = """
272- \"\"\"
273- Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.
271+ # Build the Python script without leading indentation
272+ lines = String[]
274273
275- This script was generated by Reactant.Serialization.export_to_enzymeax().
276- \"\"\"
274+ # Header
275+ push! (lines, " \"\"\" " )
276+ push! (lines, " Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX." )
277+ push! (lines, " " )
278+ push! (lines, " This script was generated by Reactant.Serialization.export_to_enzymeax()." )
279+ push! (lines, " \"\"\" " )
280+ push! (lines, " " )
281+ push! (lines, " from enzyme_ad.jax import hlo_call" )
282+ push! (lines, " import jax" )
283+ push! (lines, " import jax.numpy as jnp" )
284+ push! (lines, " import numpy as np" )
285+ push! (lines, " import os" )
286+ push! (lines, " " )
287+ push! (lines, " # Get the directory of this script" )
288+ push! (lines, " _script_dir = os.path.dirname(os.path.abspath(__file__))" )
289+ push! (lines, " " )
290+ push! (lines, " # Load the MLIR/StableHLO code" )
291+ push! (lines, " with open(os.path.join(_script_dir, \" $(mlir_rel) \" ), \" r\" ) as f:" )
292+ push! (lines, " _hlo_code = f.read()" )
293+ push! (lines, " " )
277294
278- from enzyme_ad.jax import hlo_call
279- import jax
280- import jax.numpy as jnp
281- import numpy as np
282- import os
283-
284- # Get the directory of this script
285- _script_dir = os.path.dirname(os.path.abspath(__file__))
286-
287- # Load the MLIR/StableHLO code
288- with open(os.path.join(_script_dir, "$(mlir_rel) "), "r") as f:
289- _hlo_code = f.read()
290-
291- """
292-
293- # Add function to load inputs
294- script *= """
295- def load_inputs():
296- \"\"\" Load the example inputs that were exported from Julia.\"\"\"
297- inputs = []
298- """
299-
300- for (i, input_rel) in enumerate (input_rels)
301- script *= """
302- inputs.append(np.load(os.path.join(_script_dir, "$(input_rel) ")))
303- """
295+ # Function to load inputs
296+ push! (lines, " def load_inputs():" )
297+ push! (lines, " \"\"\" Load the example inputs that were exported from Julia.\"\"\" " )
298+ push! (lines, " inputs = []" )
299+ for input_rel in input_rels
300+ push! (lines, " inputs.append(np.load(os.path.join(_script_dir, \" $(input_rel) \" )))" )
304301 end
302+ push! (lines, " return tuple(inputs)" )
303+ push! (lines, " " )
305304
306- script *= """
307- return tuple(inputs)
308-
309- """
310-
311- # Add the main function that calls the HLO code
305+ # Main function
312306 arg_names = [" arg$i " for i in 1 : length (input_paths)]
313307 arg_list = join (arg_names, " , " )
314308
315- script *= """
316- def run_$(function_name) ($(arg_list) ):
317- \"\"\"
318- Call the exported Julia function via EnzymeJAX.
319-
320- Args:
321- """
309+ push! (lines, " def run_$(function_name) ($(arg_list) ):" )
310+ push! (lines, " \"\"\" " )
311+ push! (lines, " Call the exported Julia function via EnzymeJAX." )
312+ push! (lines, " " )
313+ push! (lines, " Args:" )
322314
323315 for (i, info) in enumerate (input_info)
324316 # Note: shapes are already transposed for Python
325317 python_shape = reverse (info. shape)
326- script *= """
327- $(arg_names[i]) : Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info. dtype])
328- """
318+ push! (lines, " $(arg_names[i]) : Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info. dtype]) " )
329319 end
330320
331- script *= """
332-
333- Returns:
334- The result of calling the exported function.
335-
336- Note:
337- All inputs must be in row-major (Python/NumPy) order. If you're passing
338- arrays from Julia, make sure to transpose them first using:
339- `permutedims(arr, reverse(1:ndims(arr)))`
340- \"\"\"
341- return hlo_call(
342- $(arg_list) ,
343- source=_hlo_code,
344- )
345-
346- """
321+ push! (lines, " " )
322+ push! (lines, " Returns:" )
323+ push! (lines, " The result of calling the exported function." )
324+ push! (lines, " " )
325+ push! (lines, " Note:" )
326+ push! (lines, " All inputs must be in row-major (Python/NumPy) order. If you're passing" )
327+ push! (lines, " arrays from Julia, make sure to transpose them first using:" )
328+ push! (lines, " `permutedims(arr, reverse(1:ndims(arr)))`" )
329+ push! (lines, " \"\"\" " )
330+ push! (lines, " return hlo_call(" )
331+ push! (lines, " $(arg_list) ," )
332+ push! (lines, " source=_hlo_code," )
333+ push! (lines, " )" )
334+ push! (lines, " " )
347335
348- # Add a main block for testing
349- script *= """
350- if __name__ == "__main__":
351- # Load the example inputs
352- inputs = load_inputs()
353-
354- # Run the function (with JIT compilation)
355- print("Running $(function_name) with JIT compilation...")
356- result = jax.jit(run_$(function_name) )(*inputs)
357- print("Result:", result)
358- print("Result shape:", result.shape if hasattr(result, 'shape') else 'scalar')
359- print("Result dtype:", result.dtype if hasattr(result, 'dtype') else type(result))
360- """
336+ # Main block
337+ push! (lines, " if __name__ == \" __main__\" :" )
338+ push! (lines, " # Load the example inputs" )
339+ push! (lines, " inputs = load_inputs()" )
340+ push! (lines, " " )
341+ push! (lines, " # Run the function (with JIT compilation)" )
342+ push! (lines, " print(\" Running $(function_name) with JIT compilation...\" )" )
343+ push! (lines, " result = jax.jit(run_$(function_name) )(*inputs)" )
344+ push! (lines, " print(\" Result:\" , result)" )
345+ push! (lines, " print(\" Result shape:\" , result.shape if hasattr(result, 'shape') else 'scalar')" )
346+ push! (lines, " print(\" Result dtype:\" , result.dtype if hasattr(result, 'dtype') else type(result))" )
361347
362- write (python_path, script)
348+ # Write the script
349+ write (python_path, join (lines, " \n " ) * " \n " )
350+ return nothing
363351end
364352
365353end # module
0 commit comments