Skip to content

Commit a53e20e

Browse files
Copilotavik-pal
andcommitted
Rename export_to_enzymeax to export_to_enzymejax
Co-authored-by: wsmoses <[email protected]>
1 parent 89b7690 commit a53e20e

File tree

5 files changed

+14
-14
lines changed

5 files changed

+14
-14
lines changed

docs/src/api/serialization.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,5 @@ The generated Python script can be immediately used with JAX and EnzymeAD withou
4545
additional Julia dependencies.
4646

4747
```@docs
48-
Reactant.Serialization.export_to_enzymeax
48+
Reactant.Serialization.export_to_enzymejax
4949
```

src/serialization/EnzymeJAX.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module EnzymeJAX
33
using ..Reactant: Reactant, Compiler, MLIR, Serialization
44

55
"""
6-
export_to_enzymeax(
6+
export_to_enzymejax(
77
f,
88
args...;
99
output_dir::String=".",
@@ -50,7 +50,7 @@ x = Reactant.to_rarray(Float32[1, 2, 3])
5050
y = (; x=Reactant.to_rarray(Float32[4, 5, 6]), y=Reactant.to_rarray(Float32[7, 8, 9]))
5151
5252
# Export to EnzymeJAX
53-
python_file_path = Reactant.Serialization.export_to_enzymeax(my_function, x, y)
53+
python_file_path = Reactant.Serialization.export_to_enzymejax(my_function, x, y)
5454
```
5555
5656
Then in Python:
@@ -62,7 +62,7 @@ import jax
6262
result = jax.jit(run_my_function)(*inputs)
6363
```
6464
"""
65-
function export_to_enzymeax(
65+
function export_to_enzymejax(
6666
f, args...; output_dir::Union{String,Nothing}=nothing, function_name::String=string(f)
6767
)
6868
if output_dir === nothing
@@ -208,7 +208,7 @@ function _generate_python_script(
208208
\"\"\"
209209
Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.
210210
211-
This script was generated by Reactant.Serialization.export_to_enzymeax().
211+
This script was generated by Reactant.Serialization.export_to_enzymejax().
212212
\"\"\"
213213
214214
from enzyme_ad.jax import hlo_call

src/serialization/Serialization.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function export_as_tf_saved_model(
128128
end
129129

130130
"""
131-
export_to_enzymeax(
131+
export_to_enzymejax(
132132
f,
133133
args...;
134134
output_dir::String=".",
@@ -137,10 +137,10 @@ end
137137
138138
Export a Julia function to EnzymeJAX format for use in Python/JAX.
139139
140-
See [`EnzymeJAX.export_to_enzymeax`](@ref) for details.
140+
See [`EnzymeJAX.export_to_enzymejax`](@ref) for details.
141141
"""
142-
function export_to_enzymeax(f, args...; kwargs...)
143-
return EnzymeJAX.export_to_enzymeax(f, args...; kwargs...)
142+
function export_to_enzymejax(f, args...; kwargs...)
143+
return EnzymeJAX.export_to_enzymejax(f, args...; kwargs...)
144144
end
145145

146146
end

test/export_enzymeax.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Test
1616
y = Reactant.to_rarray(Float32[4, 5, 6])
1717

1818
# Export to EnzymeJAX
19-
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax(
19+
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax(
2020
simple_add, x, y;
2121
output_dir=tmpdir,
2222
function_name="simple_add"
@@ -46,7 +46,7 @@ using Test
4646
@test filesize(input_path) > 0
4747
end
4848

49-
println("✓ All export_to_enzymeax tests passed!")
49+
println("✓ All export_to_enzymejax tests passed!")
5050
println(" - MLIR file created: $(mlir_path)")
5151
println(" - Python file created: $(python_path)")
5252
println(" - Input files created: $(length(input_paths))")

test/export_enzymeax_comprehensive.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Test
1515
y = Reactant.to_rarray(Float32[7 8; 9 10; 11 12]) # 3x2 matrix
1616

1717
# Export to EnzymeJAX
18-
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax(
18+
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax(
1919
matrix_multiply, x, y;
2020
output_dir=tmpdir,
2121
function_name="matrix_multiply"
@@ -56,7 +56,7 @@ end
5656
y = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4))
5757

5858
# Export to EnzymeJAX
59-
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax(
59+
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax(
6060
add_3d, x, y;
6161
output_dir=tmpdir,
6262
function_name="add_3d"
@@ -86,7 +86,7 @@ end
8686

8787
x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0, 4.0])
8888

89-
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax(
89+
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymejax(
9090
simple_fn, x;
9191
output_dir=tmpdir,
9292
function_name="test_fn"

0 commit comments

Comments
 (0)