Skip to content

Commit 62d25bf

Browse files
committed
feat: add size checks
1 parent 0f04c26 commit 62d25bf

File tree

3 files changed

+28
-23
lines changed

3 files changed

+28
-23
lines changed

docs/src/api/serialization.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,14 @@ Reactant.Serialization.export_as_tf_saved_model
3030

3131
## Exporting to JAX via EnzymeAD
3232

33-
!!! note "No Dependencies Required"
33+
!!! note "Load NPZ"
3434

35-
Unlike TensorFlow SavedModel export, exporting to JAX via EnzymeAD does not require any
36-
Python dependencies at build time. It generates standalone files that can be used with
37-
EnzymeAD/JAX in Python.
35+
This export functionality requires the `NPZ` package to be loaded.
3836

3937
This export functionality generates:
38+
4039
1. A `.mlir` file containing the StableHLO representation of your Julia function
41-
2. Example input `.npy` files with properly transposed arrays (column-major → row-major)
40+
2. Input `.npz` files containing the input arrays for the function
4241
3. A Python script that wraps the function for use with `enzyme_ad.jax.hlo_call`
4342

4443
The generated Python script can be immediately used with JAX and EnzymeAD without any

src/serialization/EnzymeJAX.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ This function:
3333
- `output_dir::Union{String,Nothing}`: Directory where output files will be saved. If
3434
`nothing`, uses a temporary directory and prints the path.
3535
- `function_name::String`: Base name for generated files
36+
- `preserve_sharding::Bool`: Whether to preserve sharding information in the exported
37+
function. Defaults to `true`.
3638
3739
## Returns
3840
@@ -55,8 +57,11 @@ function my_function(x::AbstractArray, y::NamedTuple, z::Number)
5557
end
5658
5759
# Create some example inputs
58-
x = Reactant.to_rarray(Float32[1, 2, 3])
59-
y = (; x=Reactant.to_rarray(Float32[4, 5, 6]), y=Reactant.to_rarray(Float32[7, 8, 9]))
60+
x = Reactant.to_rarray(reshape(collect(Float32, 1:6), 2, 3))
61+
y = (;
62+
x=Reactant.to_rarray(reshape(collect(Float32, 7:12), 2, 3)),
63+
y=Reactant.to_rarray(reshape(collect(Float32, 13:18), 2, 3))
64+
)
6065
z = Reactant.to_rarray(10.0f0; track_numbers=true)
6166
6267
# Export to EnzymeJAX
@@ -73,7 +78,11 @@ result = jax.jit(run_my_function)(*inputs)
7378
```
7479
"""
7580
function export_to_enzymejax(
76-
f, args...; output_dir::Union{String,Nothing}=nothing, function_name::String=string(f)
81+
f,
82+
args...;
83+
output_dir::Union{String,Nothing}=nothing,
84+
function_name::String=string(f),
85+
preserve_sharding::Bool=true,
7786
)
7887
if output_dir === nothing
7988
output_dir = mktempdir(; cleanup=false)
@@ -173,6 +182,15 @@ function _generate_python_script(
173182
"\n",
174183
)
175184

185+
arg_size_checks = [
186+
"assert $(arg_names[i]).shape == $(reverse(info.shape)), f\"Expected shape of $(arg_names[i]) to be $(reverse(info.shape)). Got {$(arg_names[i]).shape} (path: $(info.path))\""
187+
for (i, info) in enumerate(input_info)
188+
]
189+
arg_dtype_checks = [
190+
"assert $(arg_names[i]).dtype == np.dtype('$(Serialization.NUMPY_SIMPLE_TYPES[info.dtype])'), f\"Expected dtype of $(arg_names[i]) to be $(Serialization.NUMPY_SIMPLE_TYPES[info.dtype]). Got {$(arg_names[i]).dtype} (path: $(info.path))\""
191+
for (i, info) in enumerate(input_info)
192+
]
193+
176194
load_inputs = ["npz_data['$(info.key)']" for info in input_info]
177195

178196
# Build the complete Python script
@@ -217,6 +235,8 @@ function _generate_python_script(
217235
arrays from Julia, make sure to transpose them first using:
218236
\`permutedims(arr, reverse(1:ndims(arr)))\`
219237
\"\"\"
238+
$(join(arg_dtype_checks, "\n "))
239+
$(join(arg_size_checks, "\n "))
220240
return hlo_call(
221241
$(arg_list),
222242
source=_hlo_code,

src/serialization/Serialization.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,6 @@ function export_as_tf_saved_model(
127127
)
128128
end
129129

130-
"""
131-
export_to_enzymejax(
132-
f,
133-
args...;
134-
output_dir::String=".",
135-
function_name::String="exported_function",
136-
)
137-
138-
Export a Julia function to EnzymeJAX format for use in Python/JAX.
139-
140-
See [`EnzymeJAX.export_to_enzymejax`](@ref) for details.
141-
"""
142-
function export_to_enzymejax(f, args...; kwargs...)
143-
return EnzymeJAX.export_to_enzymejax(f, args...; kwargs...)
144-
end
130+
const export_to_enzymejax = EnzymeJAX.export_to_enzymejax
145131

146132
end

0 commit comments

Comments
 (0)