Skip to content

Commit 7e24c31

Browse files
Copilotavik-pal
andcommitted
Fix code review issues: binary mode, ComplexF16 support, Python indentation
Co-authored-by: avik-pal <[email protected]>
1 parent 04b97ae commit 7e24c31

File tree

2 files changed

+70
-81
lines changed

2 files changed

+70
-81
lines changed

src/serialization/EnzymeJAX.jl

Lines changed: 69 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ const NUMPY_SIMPLE_TYPES = Dict(
1515
Float16 => "np.float16",
1616
Float32 => "np.float32",
1717
Float64 => "np.float64",
18-
ComplexF16 => "np.complex64", # Note: NumPy doesn't have float16 complex
1918
ComplexF32 => "np.complex64",
2019
ComplexF64 => "np.complex128",
2120
)
@@ -177,10 +176,11 @@ function _save_transposed_array(path::String, arr::AbstractArray)
177176
shape_str = join(size(transposed), ", ")
178177
header = "{'descr': '$(dtype_str)', 'fortran_order': False, 'shape': ($(shape_str),)}"
179178

180-
# Pad header to be aligned on 64 bytes
179+
# Pad header to be aligned on 64 bytes (16-byte alignment for v1.0)
180+
# Total size needs to be divisible by 16
181181
header_len = length(header) + 1 # +1 for newline
182182
total_len = 10 + header_len # 10 = magic(6) + version(2) + header_len(2)
183-
padding = (64 - (total_len % 64)) % 64
183+
padding = (16 - (total_len % 16)) % 16
184184
header = header * " "^padding * "\n"
185185
header_len = length(header)
186186

@@ -191,6 +191,7 @@ function _save_transposed_array(path::String, arr::AbstractArray)
191191
# Write data
192192
write(io, vec(transposed))
193193
end
194+
return nothing
194195
end
195196

196197
"""
@@ -267,99 +268,86 @@ function _generate_python_script(
267268
mlir_rel = relpath(mlir_path, output_dir)
268269
input_rels = [relpath(p, output_dir) for p in input_paths]
269270

270-
# Start building the Python script
271-
script = """
272-
\"\"\"
273-
Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.
271+
# Build the Python script without leading indentation
272+
lines = String[]
274273

275-
This script was generated by Reactant.Serialization.export_to_enzymeax().
276-
\"\"\"
274+
# Header
275+
push!(lines, "\"\"\"")
276+
push!(lines, "Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.")
277+
push!(lines, "")
278+
push!(lines, "This script was generated by Reactant.Serialization.export_to_enzymeax().")
279+
push!(lines, "\"\"\"")
280+
push!(lines, "")
281+
push!(lines, "from enzyme_ad.jax import hlo_call")
282+
push!(lines, "import jax")
283+
push!(lines, "import jax.numpy as jnp")
284+
push!(lines, "import numpy as np")
285+
push!(lines, "import os")
286+
push!(lines, "")
287+
push!(lines, "# Get the directory of this script")
288+
push!(lines, "_script_dir = os.path.dirname(os.path.abspath(__file__))")
289+
push!(lines, "")
290+
push!(lines, "# Load the MLIR/StableHLO code")
291+
push!(lines, "with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f:")
292+
push!(lines, " _hlo_code = f.read()")
293+
push!(lines, "")
277294

278-
from enzyme_ad.jax import hlo_call
279-
import jax
280-
import jax.numpy as jnp
281-
import numpy as np
282-
import os
283-
284-
# Get the directory of this script
285-
_script_dir = os.path.dirname(os.path.abspath(__file__))
286-
287-
# Load the MLIR/StableHLO code
288-
with open(os.path.join(_script_dir, "$(mlir_rel)"), "r") as f:
289-
_hlo_code = f.read()
290-
291-
"""
292-
293-
# Add function to load inputs
294-
script *= """
295-
def load_inputs():
296-
\"\"\"Load the example inputs that were exported from Julia.\"\"\"
297-
inputs = []
298-
"""
299-
300-
for (i, input_rel) in enumerate(input_rels)
301-
script *= """
302-
inputs.append(np.load(os.path.join(_script_dir, "$(input_rel)")))
303-
"""
295+
# Function to load inputs
296+
push!(lines, "def load_inputs():")
297+
push!(lines, " \"\"\"Load the example inputs that were exported from Julia.\"\"\"")
298+
push!(lines, " inputs = []")
299+
for input_rel in input_rels
300+
push!(lines, " inputs.append(np.load(os.path.join(_script_dir, \"$(input_rel)\")))")
304301
end
302+
push!(lines, " return tuple(inputs)")
303+
push!(lines, "")
305304

306-
script *= """
307-
return tuple(inputs)
308-
309-
"""
310-
311-
# Add the main function that calls the HLO code
305+
# Main function
312306
arg_names = ["arg$i" for i in 1:length(input_paths)]
313307
arg_list = join(arg_names, ", ")
314308

315-
script *= """
316-
def run_$(function_name)($(arg_list)):
317-
\"\"\"
318-
Call the exported Julia function via EnzymeJAX.
319-
320-
Args:
321-
"""
309+
push!(lines, "def run_$(function_name)($(arg_list)):")
310+
push!(lines, " \"\"\"")
311+
push!(lines, " Call the exported Julia function via EnzymeJAX.")
312+
push!(lines, " ")
313+
push!(lines, " Args:")
322314

323315
for (i, info) in enumerate(input_info)
324316
# Note: shapes are already transposed for Python
325317
python_shape = reverse(info.shape)
326-
script *= """
327-
$(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])
328-
"""
318+
push!(lines, " $(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])")
329319
end
330320

331-
script *= """
332-
333-
Returns:
334-
The result of calling the exported function.
335-
336-
Note:
337-
All inputs must be in row-major (Python/NumPy) order. If you're passing
338-
arrays from Julia, make sure to transpose them first using:
339-
`permutedims(arr, reverse(1:ndims(arr)))`
340-
\"\"\"
341-
return hlo_call(
342-
$(arg_list),
343-
source=_hlo_code,
344-
)
345-
346-
"""
321+
push!(lines, " ")
322+
push!(lines, " Returns:")
323+
push!(lines, " The result of calling the exported function.")
324+
push!(lines, " ")
325+
push!(lines, " Note:")
326+
push!(lines, " All inputs must be in row-major (Python/NumPy) order. If you're passing")
327+
push!(lines, " arrays from Julia, make sure to transpose them first using:")
328+
push!(lines, " `permutedims(arr, reverse(1:ndims(arr)))`")
329+
push!(lines, " \"\"\"")
330+
push!(lines, " return hlo_call(")
331+
push!(lines, " $(arg_list),")
332+
push!(lines, " source=_hlo_code,")
333+
push!(lines, " )")
334+
push!(lines, "")
347335

348-
# Add a main block for testing
349-
script *= """
350-
if __name__ == "__main__":
351-
# Load the example inputs
352-
inputs = load_inputs()
353-
354-
# Run the function (with JIT compilation)
355-
print("Running $(function_name) with JIT compilation...")
356-
result = jax.jit(run_$(function_name))(*inputs)
357-
print("Result:", result)
358-
print("Result shape:", result.shape if hasattr(result, 'shape') else 'scalar')
359-
print("Result dtype:", result.dtype if hasattr(result, 'dtype') else type(result))
360-
"""
336+
# Main block
337+
push!(lines, "if __name__ == \"__main__\":")
338+
push!(lines, " # Load the example inputs")
339+
push!(lines, " inputs = load_inputs()")
340+
push!(lines, " ")
341+
push!(lines, " # Run the function (with JIT compilation)")
342+
push!(lines, " print(\"Running $(function_name) with JIT compilation...\")")
343+
push!(lines, " result = jax.jit(run_$(function_name))(*inputs)")
344+
push!(lines, " print(\"Result:\", result)")
345+
push!(lines, " print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar')")
346+
push!(lines, " print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result))")
361347

362-
write(python_path, script)
348+
# Write the script
349+
write(python_path, join(lines, "\n") * "\n")
350+
return nothing
363351
end
364352

365353
end # module

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
3838
@safetestset "Cluster Detection" include("cluster_detector.jl")
3939
@safetestset "Config" include("config.jl")
4040
@safetestset "Batching" include("batching.jl")
41+
@safetestset "Export to EnzymeJAX" include("export_enzymeax.jl")
4142
@safetestset "QA" include("qa.jl")
4243
end
4344

0 commit comments

Comments
 (0)