Skip to content

Commit 2110340

Browse files
committed
chore: cleanup
1 parent c8127e2 commit 2110340

File tree

4 files changed

+95
-145
lines changed

4 files changed

+95
-145
lines changed

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ReactantPythonCallExt
33
using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
44
using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
55
using Reactant.Ops: @opcall
6+
using Reactant.Serialization: NUMPY_SIMPLE_TYPES
67

78
const jaxptr = Ref{Py}()
89
const jnpptr = Ref{Py}()
@@ -15,24 +16,6 @@ const npptr = Ref{Py}()
1516

1617
const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false)
1718

18-
const NUMPY_SIMPLE_TYPES = Dict(
19-
Bool => :bool,
20-
Int8 => :int8,
21-
Int16 => :int16,
22-
Int32 => :int32,
23-
Int64 => :int64,
24-
UInt8 => :uint8,
25-
UInt16 => :uint16,
26-
UInt32 => :uint32,
27-
UInt64 => :uint64,
28-
Float16 => :float16,
29-
Float32 => :float32,
30-
Float64 => :float64,
31-
ComplexF16 => :complex16,
32-
ComplexF32 => :complex32,
33-
ComplexF64 => :complex64,
34-
)
35-
3619
function __init__()
3720
try
3821
jaxptr[] = pyimport("jax")

src/serialization/EnzymeJAX.jl

Lines changed: 75 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,6 @@
11
module EnzymeJAX
22

3-
using ..Reactant: Reactant, Compiler, MLIR
4-
5-
const NUMPY_SIMPLE_TYPES = Dict(
6-
Bool => "np.bool_",
7-
Int8 => "np.int8",
8-
Int16 => "np.int16",
9-
Int32 => "np.int32",
10-
Int64 => "np.int64",
11-
UInt8 => "np.uint8",
12-
UInt16 => "np.uint16",
13-
UInt32 => "np.uint32",
14-
UInt64 => "np.uint64",
15-
Float16 => "np.float16",
16-
Float32 => "np.float32",
17-
Float64 => "np.float64",
18-
ComplexF32 => "np.complex64",
19-
ComplexF64 => "np.complex128",
20-
)
3+
using ..Reactant: Reactant, Compiler, MLIR, Serialization
214

225
"""
236
export_to_enzymeax(
@@ -264,104 +247,88 @@ function _generate_python_script(
264247
mlir_rel = relpath(mlir_path, output_dir)
265248
input_rels = [relpath(p, output_dir) for p in input_paths]
266249

267-
# Build the Python script without leading indentation
268-
lines = String[]
269-
270-
# Header
271-
push!(lines, "\"\"\"")
272-
push!(
273-
lines,
274-
"Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.",
275-
)
276-
push!(lines, "")
277-
push!(
278-
lines, "This script was generated by Reactant.Serialization.export_to_enzymeax()."
250+
# Generate input loading code
251+
input_loads = join(
252+
[
253+
" inputs.append(np.load(os.path.join(_script_dir, \"$rel\")))" for
254+
rel in input_rels
255+
],
256+
"\n",
279257
)
280-
push!(lines, "\"\"\"")
281-
push!(lines, "")
282-
push!(lines, "from enzyme_ad.jax import hlo_call")
283-
push!(lines, "import jax")
284-
push!(lines, "import jax.numpy as jnp")
285-
push!(lines, "import numpy as np")
286-
push!(lines, "import os")
287-
push!(lines, "")
288-
push!(lines, "# Get the directory of this script")
289-
push!(lines, "_script_dir = os.path.dirname(os.path.abspath(__file__))")
290-
push!(lines, "")
291-
push!(lines, "# Load the MLIR/StableHLO code")
292-
push!(lines, "with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f:")
293-
push!(lines, " _hlo_code = f.read()")
294-
push!(lines, "")
295-
296-
# Function to load inputs
297-
push!(lines, "def load_inputs():")
298-
push!(lines, " \"\"\"Load the example inputs that were exported from Julia.\"\"\"")
299-
push!(lines, " inputs = []")
300-
for input_rel in input_rels
301-
push!(
302-
lines, " inputs.append(np.load(os.path.join(_script_dir, \"$(input_rel)\")))"
303-
)
304-
end
305-
push!(lines, " return tuple(inputs)")
306-
push!(lines, "")
307258

308-
# Main function
259+
# Generate argument list and documentation
309260
arg_names = ["arg$i" for i in 1:length(input_paths)]
310261
arg_list = join(arg_names, ", ")
311262

312-
push!(lines, "def run_$(function_name)($(arg_list)):")
313-
push!(lines, " \"\"\"")
314-
push!(lines, " Call the exported Julia function via EnzymeJAX.")
315-
push!(lines, " ")
316-
push!(lines, " Args:")
317-
318-
for (i, info) in enumerate(input_info)
319-
# Note: shapes are already transposed for Python
320-
python_shape = reverse(info.shape)
321-
push!(
322-
lines,
323-
" $(arg_names[i]): Array of shape $(python_shape) and dtype $(NUMPY_SIMPLE_TYPES[info.dtype])",
263+
# Generate docstring for arguments
264+
arg_docs = join(
265+
[
266+
" $(arg_names[i]): Array of shape $(reverse(info.shape)) and dtype $(Serialization.NUMPY_SIMPLE_TYPES[info.dtype])"
267+
for (i, info) in enumerate(input_info)
268+
],
269+
"\n",
270+
)
271+
272+
# Build the complete Python script
273+
script = """
274+
\"\"\"
275+
Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.
276+
277+
This script was generated by Reactant.Serialization.export_to_enzymeax().
278+
\"\"\"
279+
280+
from enzyme_ad.jax import hlo_call
281+
import jax
282+
import jax.numpy as jnp
283+
import numpy as np
284+
import os
285+
286+
# Get the directory of this script
287+
_script_dir = os.path.dirname(os.path.abspath(__file__))
288+
289+
# Load the MLIR/StableHLO code
290+
with open(os.path.join(_script_dir, \"$(mlir_rel)\"), \"r\") as f:
291+
_hlo_code = f.read()
292+
293+
def load_inputs():
294+
\"\"\"Load the example inputs that were exported from Julia.\"\"\"
295+
inputs = []
296+
$input_loads
297+
return tuple(inputs)
298+
299+
def run_$(function_name)($(arg_list)):
300+
\"\"\"
301+
Call the exported Julia function via EnzymeJAX.
302+
303+
Args:
304+
$arg_docs
305+
306+
Returns:
307+
The result of calling the exported function.
308+
309+
Note:
310+
All inputs must be in row-major (Python/NumPy) order. If you're passing
311+
arrays from Julia, make sure to transpose them first using:
312+
\`permutedims(arr, reverse(1:ndims(arr)))\`
313+
\"\"\"
314+
return hlo_call(
315+
$(arg_list),
316+
source=_hlo_code,
324317
)
325-
end
326318
327-
push!(lines, " ")
328-
push!(lines, " Returns:")
329-
push!(lines, " The result of calling the exported function.")
330-
push!(lines, " ")
331-
push!(lines, " Note:")
332-
push!(
333-
lines,
334-
" All inputs must be in row-major (Python/NumPy) order. If you're passing",
335-
)
336-
push!(lines, " arrays from Julia, make sure to transpose them first using:")
337-
push!(lines, " `permutedims(arr, reverse(1:ndims(arr)))`")
338-
push!(lines, " \"\"\"")
339-
push!(lines, " return hlo_call(")
340-
push!(lines, " $(arg_list),")
341-
push!(lines, " source=_hlo_code,")
342-
push!(lines, " )")
343-
push!(lines, "")
344-
345-
# Main block
346-
push!(lines, "if __name__ == \"__main__\":")
347-
push!(lines, " # Load the example inputs")
348-
push!(lines, " inputs = load_inputs()")
349-
push!(lines, " ")
350-
push!(lines, " # Run the function (with JIT compilation)")
351-
push!(lines, " print(\"Running $(function_name) with JIT compilation...\")")
352-
push!(lines, " result = jax.jit(run_$(function_name))(*inputs)")
353-
push!(lines, " print(\"Result:\", result)")
354-
push!(
355-
lines,
356-
" print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar')",
357-
)
358-
push!(
359-
lines,
360-
" print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result))",
361-
)
319+
if __name__ == \"__main__\":
320+
# Load the example inputs
321+
inputs = load_inputs()
322+
323+
# Run the function (with JIT compilation)
324+
print(\"Running $(function_name) with JIT compilation...\")
325+
result = jax.jit(run_$(function_name))(*inputs)
326+
print(\"Result:\", result)
327+
print(\"Result shape:\", result.shape if hasattr(result, 'shape') else 'scalar')
328+
print(\"Result dtype:\", result.dtype if hasattr(result, 'dtype') else type(result))
329+
"""
362330

363-
# Write the script
364-
write(python_path, join(lines, "\n") * "\n")
331+
write(python_path, strip(script) * "\n")
365332
return nothing
366333
end
367334

src/serialization/Serialization.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,24 @@ using ..Reactant: Reactant, Compiler
1010

1111
serialization_supported(::Val) = false
1212

13+
const NUMPY_SIMPLE_TYPES = Dict(
14+
Bool => :bool,
15+
Int8 => :int8,
16+
Int16 => :int16,
17+
Int32 => :int32,
18+
Int64 => :int64,
19+
UInt8 => :uint8,
20+
UInt16 => :uint16,
21+
UInt32 => :uint32,
22+
UInt64 => :uint64,
23+
Float16 => :float16,
24+
Float32 => :float32,
25+
Float64 => :float64,
26+
ComplexF16 => :complex16,
27+
ComplexF32 => :complex32,
28+
ComplexF64 => :complex64,
29+
)
30+
1331
include("TFSavedModel.jl")
1432
include("EnzymeJAX.jl")
1533

src/serialization/TFSavedModel.jl

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,10 @@
11
module TFSavedModel
22

3-
using ..Serialization: serialization_supported
3+
using ..Serialization: serialization_supported, NUMPY_SIMPLE_TYPES
44
using ..Reactant: Compiler, MLIR
55

66
# https://github.com/openxla/stablehlo/blob/955fa7e6e3b0a6411edc8ff6fcce1e644440acbd/stablehlo/integrations/python/stablehlo/savedmodel/stablehlo_to_tf_saved_model.py
77

8-
const NUMPY_SIMPLE_TYPES = Dict(
9-
Bool => :bool,
10-
Int8 => :int8,
11-
Int16 => :int16,
12-
Int32 => :int32,
13-
Int64 => :int64,
14-
UInt8 => :uint8,
15-
UInt16 => :uint16,
16-
UInt32 => :uint32,
17-
UInt64 => :uint64,
18-
Float16 => :float16,
19-
Float32 => :float32,
20-
Float64 => :float64,
21-
ComplexF16 => :complex16,
22-
ComplexF32 => :complex32,
23-
ComplexF64 => :complex64,
24-
)
25-
268
struct VariableSignature
279
shape::Vector{Int}
2810
dtype::Symbol

0 commit comments

Comments
 (0)