Skip to content

Commit 04b97ae

Browse files
Copilotavik-pal
andcommitted
Add comprehensive tests for export_to_enzymeax
Co-authored-by: avik-pal <[email protected]>
1 parent ae75160 commit 04b97ae

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
using Reactant
2+
using Test
3+
4+
@testset "Export to EnzymeJAX - Multi-dimensional Arrays" begin
5+
tmpdir = mktempdir()
6+
7+
try
8+
# Define a function with 2D arrays
9+
function matrix_multiply(x, y)
10+
return x * y
11+
end
12+
13+
# Create 2D arrays - Julia uses column-major order
14+
x = Reactant.to_rarray(Float32[1 2 3; 4 5 6]) # 2x3 matrix
15+
y = Reactant.to_rarray(Float32[7 8; 9 10; 11 12]) # 3x2 matrix
16+
17+
# Export to EnzymeJAX
18+
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax(
19+
matrix_multiply, x, y;
20+
output_dir=tmpdir,
21+
function_name="matrix_multiply"
22+
)
23+
24+
@test isfile(mlir_path)
25+
@test isfile(python_path)
26+
@test length(input_paths) == 2
27+
28+
# Read Python file and check for correct shape information
29+
python_content = read(python_path, String)
30+
31+
# The shapes should be transposed for Python (row-major)
32+
# Julia x: (2, 3) -> Python: (3, 2)
33+
# Julia y: (3, 2) -> Python: (2, 3)
34+
@test occursin("(3, 2)", python_content) # Transposed shape of x
35+
@test occursin("(2, 3)", python_content) # Transposed shape of y
36+
37+
println("✓ Multi-dimensional array export test passed!")
38+
39+
finally
40+
rm(tmpdir; recursive=true, force=true)
41+
end
42+
end
43+
44+
@testset "Export to EnzymeJAX - 3D Arrays" begin
45+
tmpdir = mktempdir()
46+
47+
try
48+
# Define a function with 3D arrays (like image data)
49+
function add_3d(x, y)
50+
return x .+ y
51+
end
52+
53+
# Create 3D arrays - e.g., (height, width, channels, batch)
54+
# Julia: (28, 28, 1, 4) -> Python: (4, 1, 28, 28)
55+
x = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4))
56+
y = Reactant.to_rarray(rand(Float32, 28, 28, 1, 4))
57+
58+
# Export to EnzymeJAX
59+
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax(
60+
add_3d, x, y;
61+
output_dir=tmpdir,
62+
function_name="add_3d"
63+
)
64+
65+
@test isfile(mlir_path)
66+
@test isfile(python_path)
67+
68+
# Check that Python file mentions the transposed shape
69+
python_content = read(python_path, String)
70+
@test occursin("(4, 1, 28, 28)", python_content)
71+
72+
println("✓ 3D array export test passed!")
73+
74+
finally
75+
rm(tmpdir; recursive=true, force=true)
76+
end
77+
end
78+
79+
@testset "Export to EnzymeJAX - File Content Verification" begin
80+
tmpdir = mktempdir()
81+
82+
try
83+
function simple_fn(x)
84+
return x .* 2.0f0
85+
end
86+
87+
x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0, 4.0])
88+
89+
mlir_path, python_path, input_paths = Reactant.Serialization.export_to_enzymeax(
90+
simple_fn, x;
91+
output_dir=tmpdir,
92+
function_name="test_fn"
93+
)
94+
95+
# Verify MLIR contains necessary elements
96+
mlir_content = read(mlir_path, String)
97+
@test occursin("module", mlir_content)
98+
99+
# Verify Python file structure
100+
python_content = read(python_path, String)
101+
@test occursin("import jax", python_content)
102+
@test occursin("import numpy as np", python_content)
103+
@test occursin("from enzyme_ad.jax import hlo_call", python_content)
104+
@test occursin("def run_test_fn(arg1)", python_content)
105+
@test occursin("source=_hlo_code", python_content)
106+
@test occursin("jax.jit(run_test_fn)", python_content)
107+
108+
println("✓ File content verification test passed!")
109+
110+
finally
111+
rm(tmpdir; recursive=true, force=true)
112+
end
113+
end
114+
115+
println("\n✅ All comprehensive tests passed!")

0 commit comments

Comments
 (0)