Skip to content

Conversation

Copy link

Copilot AI commented Dec 6, 2025

Refactored export_to_enzymejax to use Reactant's existing make_tracer infrastructure as suggested by @wsmoses. Instead of manually creating an inverse map from traced args to concrete args, we now directly iterate over mlir_fn_res.seen_args which is an OrderedIdDict that already contains this mapping.

Changes

  • Removed manual invmap creation (lines 88-91)
  • Now iterate directly over mlir_fn_res.seen_args which maps concrete args to traced args
  • Filter by checking if the traced arg is in mlir_fn_res.linear_args
  • Maintains the same functional behavior while using the existing infrastructure
  • Renamed function from export_to_enzymeax to export_to_enzymejax for naming consistency

This aligns with the design of Reactant's compilation infrastructure where seen_args is already populated during the tracing process.

Original prompt

This section details on the original issue you should resolve

<issue_title>Export a compile function to jax via enzyme_ad</issue_title>
<issue_description>Follow this tutorial for a setup

# [Exporting Lux Models to Jax (via EnzymeJAX & Reactant)](@id exporting_to_stablehlo)

In this manual, we will go over how to export Lux models to StableHLO and use
[EnzymeJAX](https://github.com/EnzymeAD/Enzyme-JAX) to run integrate Lux models with
JAX. We assume that users are familiar with
[Reactant compilation of Lux models](@ref reactant-compilation).

```@example exporting_to_stablehlo
using Lux, Reactant, Random

const dev = reactant_device()
```

We simply define a Lux model and generate the stablehlo code using `Reactant.@code_hlo`.

```@example exporting_to_stablehlo
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
nothing # hide
```

Generate an example input.

```@example exporting_to_stablehlo
x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev
nothing # hide
```

Now instead of compiling the model, we will use `Reactant.@code_hlo` to generate the
StableHLO code.

```@example exporting_to_stablehlo
hlo_code = @code_hlo model(x, ps, st)
```

Now we just save this into an `mlir` file.

```@example exporting_to_stablehlo
write("exported_lux_model.mlir", string(hlo_code))
nothing # hide
```

Now we define a python script to run the model using EnzymeJAX.

```python
from enzyme_ad.jax import hlo_call

import jax
import jax.numpy as jnp

with open("exported_lux_model.mlir", "r") as file:
    code = file.read()


def run_lux_model(
    x,
    weight1,
    bias1,
    weight3,
    bias3,
    weight6_1,
    bias6_1,
    weight6_2,
    bias6_2,
    weight6_3,
    bias6_3,
):
    return hlo_call(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
        source=code,
    )


# Note that all the inputs must be transposed, i.e. if the julia function has an input of
# shape (28, 28, 1, 4), then the input to the exported function called from python must be
# of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in
# column-major order, while in JAX/Python they are stored in row-major order.

# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))

# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))

# Run the exported Lux model
print(
    jax.jit(run_lux_model)(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
    )
)
```

We need to automate generating the python file. dump the inputs to the function into files that are read from the files in python, make sure to transpose before saving.</issue_description>

Comments on the Issue (you are @copilot in this section)


✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.

Copilot AI changed the title [WIP] Export compile function to jax via enzyme_ad Add export_to_enzymeax for automated JAX/EnzymeAD integration Dec 6, 2025
Copilot AI requested a review from avik-pal December 6, 2025 17:37
Copilot finished work on behalf of avik-pal December 6, 2025 17:37
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit

JuliaFormatter

[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶

simple_add, x, y;
output_dir=tmpdir,
function_name="simple_add"


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶

matrix_multiply, x, y;
output_dir=tmpdir,
function_name="matrix_multiply"


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶

add_3d, x, y;
output_dir=tmpdir,
function_name="add_3d"


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶

simple_fn, x;
output_dir=tmpdir,
function_name="test_fn"


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit

JuliaFormatter

[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶


[JuliaFormatter] reported by reviewdog 🐶

@testset "Export to EnzymeJAX" begin
# Create a temporary directory for the export
tmpdir = mktempdir()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

function simple_add(x, y)
return x .+ y
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

# Create some example inputs
x = Reactant.to_rarray(Float32[1, 2, 3])
y = Reactant.to_rarray(Float32[4, 5, 6])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

Comment on lines 20 to 22
simple_add, x, y;
output_dir=tmpdir,
function_name="simple_add"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
simple_add, x, y;
output_dir=tmpdir,
function_name="simple_add"
simple_add, x, y; output_dir=tmpdir, function_name="simple_add"

output_dir=tmpdir,
function_name="simple_add"
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@test occursin("(4, 1, 28, 28)", python_content)

println("✓ 3D array export test passed!")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change


@testset "Export to EnzymeJAX - File Content Verification" begin
tmpdir = mktempdir()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

function simple_fn(x)
return x .* 2.0f0
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

end

x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0, 4.0])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

Comment on lines 90 to 92
simple_fn, x;
output_dir=tmpdir,
function_name="test_fn"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
simple_fn, x;
output_dir=tmpdir,
function_name="test_fn"
simple_fn, x; output_dir=tmpdir, function_name="test_fn"


for (i, arg) in enumerate(args)
# Convert to array if needed
arr = _to_array(arg)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot this shouldn't be necessary. You should be able to re-use the existing infrastructure within Reactant.jl's make_tracer infrastructure to get a list of all the actual linearized arguments [in fact the same linearization used by compilation is required].

cc @avik-pal

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to use the existing seen_args infrastructure. Now iterating directly over mlir_fn_res.seen_args (OrderedIdDict mapping concrete args → traced args) and filtering by linear_args instead of creating a manual inverse map. Commit: 89b7690

end

"""
export_to_enzymeax(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot rename to export_to_enzymejax

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to export_to_enzymejax across all files. Commit: a53e20e

@avik-pal
Copy link
Collaborator

avik-pal commented Dec 6, 2025

On running the export pipeline (identical to how compile works)

"""
Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.

This script was generated by Reactant.Serialization.export_to_enzymeax().
"""

from enzyme_ad.jax import hlo_call
import jax
import jax.numpy as jnp
import numpy as np
import os

# Get the directory of this script
_script_dir = os.path.dirname(os.path.abspath(__file__))

# Load the MLIR/StableHLO code
with open(os.path.join(_script_dir, "my_function_0.mlir"), "r") as f:
    _hlo_code = f.read()

def load_inputs():
    """Load the example inputs that were exported from Julia."""
    inputs = []
    inputs.append(np.load(os.path.join(_script_dir, "my_function_0_input_1.npy")))
    inputs.append(np.load(os.path.join(_script_dir, "my_function_0_input_2.npy")))
    inputs.append(np.load(os.path.join(_script_dir, "my_function_0_input_3.npy")))
    return tuple(inputs)

def run_my_function(arg1, arg2, arg3):
    """
    Call the exported Julia function via EnzymeJAX.

    Args:
        arg1: Array of shape (3,) and dtype float32
        arg2: Array of shape (3,) and dtype float32
        arg3: Array of shape (3,) and dtype float32

    Returns:
        The result of calling the exported function.

    Note:
        All inputs must be in row-major (Python/NumPy) order. If you're passing
        arrays from Julia, make sure to transpose them first using:
        `permutedims(arr, reverse(1:ndims(arr)))`
    """
    return hlo_call(
        arg1, arg2, arg3,
        source=_hlo_code,
    )

if __name__ == "__main__":
    # Load the example inputs
    inputs = load_inputs()

    # Run the function (with JIT compilation)
    print("Running my_function with JIT compilation...")
    result = jax.jit(run_my_function)(*inputs)
    print("Result:", result)

Copy link
Author

Copilot AI commented Dec 6, 2025

Just as a heads up, I was blocked by some firewall rules while working on your feedback. Expand below for details.

Warning

Firewall rules blocked me from connecting to one or more addresses (expand for details)

I tried to connect to the following addresses, but was blocked by firewall rules:

  • 127.255.255.255
    • Triggering command: REDACTED, pid is -1 (packet block)
  • https://api.github.com/repos/EnzymeAD/Enzyme.jl/tarball/820f06722a87d9544f42679182eb0850690f9b45
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/EnzymeAD/Enzyme.jl/tarball/af32c93dfb17c5f78c90e45f4547e6733552d78a
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/EnzymeAD/Reactant_jll.jl/tarball/e103ed74bf68b1648ca9a673899fa6afcf12fb88
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/FluxML/Functors.jl/tarball/60a0339f28a233601cb74468032b5c302d5067de
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/FluxML/MacroTools.jl/tarball/1e0228a030642014fe5cfe68c2c0a818f9e3f522
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaBinaryWrappers/Enzyme_jll.jl/tarball/6096df88a1bc09afb9a1c85d4e54ed085a95e799
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaBinaryWrappers/LLVMExtra_jll.jl/tarball/8e76807afb59ebb833e9b131ebf1a8c006510f33
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaBinaryWrappers/LLVMOpenMP_jll.jl/tarball/eb62a3deb62fc6d8822c0c4bef73e4412419c5d8
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaBinaryWrappers/LibTracyClient_jll.jl/tarball/d2bc4e1034b2d43076b50f0e34ea094c2cb0a717
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaBinaryWrappers/MbedTLS_jll.jl/tarball/ff69a2b1330bcb730b9ac1ab7dd680176f5896b8
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaCollections/OrderedCollections.jl/tarball/05868e21324cede2207c6f0f466b4bfef6d5e7ee
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaGPU/Adapt.jl/tarball/7e35fca2bdfba44d797c53dfe63a51fabf39bfc0
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaGPU/GPUArrays.jl/tarball/83cf05ab16a73219e5f6bd1bdfa9848fa24ac627
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaGPU/GPUCompiler.jl/tarball/6e5a25bc455da8e8d88b6b7377e341e9af1929f0
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaIO/CodecZlib.jl/tarball/962834c22b66e32aa10f7611c08c8ca4e20749a9
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaIO/ObjectFile.jl/tarball/22faba70c22d2f03e60fbc61da99c4ebfc3eb9ba
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaIO/StructIO.jl/tarball/c581be48ae1cbf83e899b14c07a807e1787512cc
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaIO/TranscodingStreams.jl/tarball/0c45878dcfdcfa8480052b6ab162cdd138781742
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaInterop/CEnum.jl/tarball/389ad5c84de1ae7cf0e28e381131c98ea87d54fc
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaLLVM/LLVM.jl/tarball/ce8614210409eaa54ed5968f4b50aa96da7ae543
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaLang/Compat.jl/tarball/9d8a54ce4b17aa5bdce0ea5c34bc5e7c340d16ad
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaLang/MbedTLS.jl/tarball/c067a280ddc25f196b5e7df3877c6b226d390aaf
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaLang/PrecompileTools.jl/tarball/07a921781cab75691315adc645096ed5e370cb77
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaLang/ScopedValues.jl/tarball/c3b2323466378a2ba15bea4b2f73b081e022f473
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaLogging/LoggingExtras.jl/tarball/f00544d95982ea270145636c181ceda21c4e2575
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaObjects/ConstructionBase.jl/tarball/b4b092499347b18a015186eae3042f72267106cb
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaPackaging/JLLWrappers.jl/tarball/0533e564aae234aff59ab625543145446d8b6ec2
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaPackaging/Preferences.jl/tarball/0f27480397253da18fe2c12a4ba4eb9eb208bf3d
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaPackaging/Requires.jl/tarball/62389eeff14780bfe55195b7204c0d8738436d64
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaPackaging/Scratch.jl/tarball/9b81b8393e50b7d4e6d0a9f14e192294d3b7c109
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaPackaging/SimpleBufferStream.jl/tarball/f305871d2f381d21527c770d4788c06c097c9bc1
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaPluto/ExpressionExplorer.jl/tarball/4a8c0a9eebf807ac42f0f6de758e60a20be25ffb
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaServices/ConcurrentUtilities.jl/tarball/d9d26935a0bcffc87d2613ce14c527c99fc543fd
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaServices/ExceptionUnwrapping.jl/tarball/d36f682e590a83d63d1c7dbd287573764682d12a
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaTesting/ExprTools.jl/tarball/27415f162e6028e81c72b82ef756bf321213b6ec
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaWeb/HTTP.jl/tarball/5e6fe50ae7f23d171f44e311c2960294aaa0beb5
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaWeb/OpenSSL.jl/tarball/1d1aaa7d449b58415f97d2839c318b70ffb525a0
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/JuliaWeb/URIs.jl/tarball/bef26fb046d031353ef97a82e3fdb6afe7f21b1a
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/fredrikekre/EnumX.jl/tarball/bddad79635af6aec424f53ed8aad5d7555dc6f00
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/jmert/BitFlags.jl/tarball/0691e34b3bb8be9307330f88d1a3c3f25466c24d
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/simonster/Reexport.jl/tarball/45e428421666073eab6f2da5c9d310d99bb12f9b
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/topolarity/Tracy.jl/tarball/73e3ff50fd3990874c59fef0f35d10644a1487bc
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • https://api.github.com/repos/vchuravy/HashArrayMappedTries.jl/tarball/2eaa69a7cab70a52b9687c8bf950a5a93ec895ae
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (http block)
  • pkg.julialang.org
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate() (dns block)

If you need me to access, download, or install something from one of these locations, you can either:

Copilot AI requested a review from wsmoses December 6, 2025 19:16
Copilot finished work on behalf of wsmoses December 6, 2025 19:16
@wsmoses
Copy link
Member

wsmoses commented Dec 6, 2025

@avik-pal specifically for the GB debugging sake, could we also figure out how to pass the sharding info as well?

Copilot AI changed the title Add export_to_enzymeax for automated JAX/EnzymeAD integration Add export_to_enzymejax for automated JAX/EnzymeAD integration Dec 6, 2025
Copilot finished work on behalf of avik-pal December 6, 2025 19:20
@avik-pal
Copy link
Collaborator

avik-pal commented Dec 6, 2025

@avik-pal specifically for the GB debugging sake, could we also figure out how to pass the sharding info as well?

should work now

@avik-pal avik-pal force-pushed the copilot/export-compile-function-to-jax branch from 73c2efa to 68eaf8a Compare December 6, 2025 21:37
@avik-pal avik-pal force-pushed the copilot/export-compile-function-to-jax branch from 68eaf8a to c86486d Compare December 6, 2025 22:08
@avik-pal avik-pal force-pushed the copilot/export-compile-function-to-jax branch from e39cd36 to ac13dff Compare December 6, 2025 23:28
@avik-pal avik-pal marked this pull request as ready for review December 6, 2025 23:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Export a compile function to jax via enzyme_ad

3 participants