
Conversation

@gdalle (Contributor) commented Nov 29, 2025


## Kernel compilation

To compile such kernels with Reactant, you need to pass the option `raise=true` to the `@compile` or `@jit` macro.

Member:

You don't need to pass `raise=true` to compile with Reactant; the kernels will just run natively as written. If you want to convert a kernel to tensor form to enable linear algebra optimizations, `raise` needs to be set to `true`.
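
In concrete terms, a minimal sketch of the two modes, reusing the `square` function and `x` array that appear later in this thread:

```julia
# Without raising: the kernel runs natively as written
# (this needs a working GPU setup for the kernel's backend):
y = @jit square(x)

# With raising: the kernel is converted to tensor form,
# enabling linear algebra optimizations:
y = @jit raise=true square(x)
```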

Contributor Author (@gdalle):

I'm on a Mac M3, and when I try kernel compilation without raising I get:

julia> y = @jit square(x)
ERROR: CUDA driver not found
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] functional
    @ ~/.julia/packages/CUDA/x8d2s/src/initialization.jl:24 [inlined]
  [3] task_local_state!()
    @ CUDA ~/.julia/packages/CUDA/x8d2s/lib/cudadrv/state.jl:77
  [4] active_state
    @ ~/.julia/packages/CUDA/x8d2s/lib/cudadrv/state.jl:110 [inlined]
  [5] cufunction(f::typeof(gpu_square_kernel!), tt::Type{Tuple{…}}; kwargs::@Kwargs{})
    @ CUDA ~/.julia/packages/CUDA/x8d2s/src/compiler/execution.jl:366
  [6] cufunction(f::typeof(gpu_square_kernel!), tt::Type{Tuple{…}})
    @ CUDA ~/.julia/packages/CUDA/x8d2s/src/compiler/execution.jl:365
  [7] launch_configuration(f::ReactantCUDAExt.LLVMFunc{…}; shmem::Int64, max_threads::Int64)
    @ ReactantCUDAExt ~/.julia/packages/Reactant/zlIsO/ext/ReactantCUDAExt.jl:619
  [8] ka_with_reactant
    @ ~/.julia/packages/Reactant/zlIsO/ext/ReactantCUDAExt.jl:525 [inlined]
  [9] (::Nothing)(none::typeof(Reactant.ka_with_reactant), none::Int64, none::Nothing, none::KernelAbstractions.Kernel{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [10] launch_config
    @ ~/.julia/packages/Reactant/zlIsO/ext/ReactantKernelAbstractionsExt.jl:68 [inlined]
 [11] ka_with_reactant
    @ ~/.julia/packages/Reactant/zlIsO/ext/ReactantCUDAExt.jl:504 [inlined]
 [12] call_with_reactant(::typeof(Reactant.ka_with_reactant), ::Int64, ::Nothing, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/zlIsO/src/utils.jl:0
 [13] (::KernelAbstractions.Kernel{…})(::Any, ::Vararg{…}; ndrange::Int64, workgroupsize::Nothing)
    @ ReactantKernelAbstractionsExt ~/.julia/packages/Reactant/zlIsO/ext/ReactantKernelAbstractionsExt.jl:118
 [14] #kwcall
    @ ~/.julia/packages/Reactant/zlIsO/ext/ReactantKernelAbstractionsExt.jl:113 [inlined]
 [15] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{}, none::KernelAbstractions.Kernel{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [16] call_with_reactant(::typeof(Core.kwcall), ::@NamedTuple{}, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/zlIsO/src/utils.jl:523
 [17] square
    @ ~/Documents/GitHub/Julia/Reactant.jl/test/playground.jl:15 [inlined]
 [18] (::Nothing)(none::typeof(square), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [19] getproperty
    @ ./Base.jl:49 [inlined]
 [20] size
    @ ~/.julia/packages/Reactant/zlIsO/src/TracedRArray.jl:259 [inlined]
 [21] axes
    @ ./abstractarray.jl:98 [inlined]
 [22] similar
    @ ./abstractarray.jl:821 [inlined]
 [23] similar
    @ ./abstractarray.jl:820 [inlined]
 [24] square
    @ ~/Documents/GitHub/Julia/Reactant.jl/test/playground.jl:12 [inlined]
 [25] call_with_reactant(::typeof(square), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/zlIsO/src/utils.jl:0
 [26] make_mlir_fn(f::typeof(square), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/zlIsO/src/TracedUtils.jl:345
 [27] make_mlir_fn
    @ ~/.julia/packages/Reactant/zlIsO/src/TracedUtils.jl:275 [inlined]
 [28] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(square), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1614
 [29] compile_mlir!
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1576 [inlined]
 [30] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3524
 [31] compile_xla
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3496 [inlined]
 [32] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3600
 [33] top-level scope
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:2669
Some type information was truncated. Use `show(err)` to see complete types.


To compile such kernels with Reactant, you need to pass the option `raise=true` to the `@compile` or `@jit` macro.
Furthermore, the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package needs to be loaded (even on non-NVIDIA hardware).
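
For illustration, here is a minimal end-to-end sketch of the setup this paragraph describes. The kernel body is a reconstruction (the PR's actual `square` code is not shown in this thread), but the final `@jit raise=true` call matches the one used below:

```julia
using CUDA                       # needs to be loaded even on non-NVIDIA hardware
using Reactant, KernelAbstractions

# a simple KernelAbstractions kernel squaring each element
@kernel function square_kernel!(y, @Const(x))
    i = @index(Global)
    @inbounds y[i] = x[i]^2
end

function square(x)
    y = similar(x)
    backend = KernelAbstractions.get_backend(x)
    square_kernel!(backend)(y, x; ndrange=length(x))
    return y
end

x = Reactant.to_rarray(rand(10))
y = @jit raise=true square(x)    # raising converts the kernel to tensor form
```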

Member:

Hm, this is a bug; in principle we should force-load this for you [though in the background it will load CUDA.jl].

Contributor Author (@gdalle) commented Nov 29, 2025:

Without loading CUDA first, I get the error I mentioned on Discourse:

julia> y = @jit raise=true square(x)
ERROR: MethodError: no method matching ka_with_reactant(::Int64, ::Nothing, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
The function `ka_with_reactant` exists, but no method is defined for this combination of argument types.
Attempted to raise a KernelAbstractions kernel with Reactant but CUDA.jl is not loaded.
Load CUDA.jl using `using CUDA`. You might need to restart the Julia process (even if Revise.jl is loaded).
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/zlIsO/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::typeof(Reactant.ka_with_reactant), ::Int64, ::Nothing, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/zlIsO/src/utils.jl:944
  [3] (::KernelAbstractions.Kernel{…})(::Any, ::Vararg{…}; ndrange::Int64, workgroupsize::Nothing)
    @ ReactantKernelAbstractionsExt ~/.julia/packages/Reactant/zlIsO/ext/ReactantKernelAbstractionsExt.jl:118
...


## Differentiated kernel

In addition, if you want to compute derivatives of your kernel with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl), the option `raise_first=true` also becomes necessary.

Member:

Can you explain more about what raising is? I think this doc would be better written as a tutorial on raising rather than kernels.

Member:

Re this specifically: currently we only support differentiation rules for the raised tensors [the internal kernel representation derivatives are in progress, cc @Pangoraw @avik-pal etc.].

For now raising will enable differentiation to succeed, but raising must also be performed prior to differentiation, hence `raise_first`.
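
A hedged sketch of what that might look like in practice (the `square` function is reused from above; the exact Enzyme entry point is an assumption on my part):

```julia
using Enzyme, Reactant

sumsquare(x) = sum(square(x))

# raising must happen before differentiation,
# hence raise_first=true in addition to raise=true:
grad = @jit raise=true raise_first=true Enzyme.gradient(Reverse, sumsquare, x)
```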

Contributor Author (@gdalle):

> Can you explain more about what raising is? I think this doc would be better written as a tutorial on raising rather than kernels.

No, I can't. I don't know the first thing about how raising works (besides what @yolhan83 has taught me on Discourse); I only know that it seems to be necessary on my laptop for handwritten kernels to get compiled.

The goals of this PR are:

  1. Showing that interoperability with custom kernels is an important aspect for Reactant users, one that needs to be documented.
  2. Getting the page started so that people who actually know what they're talking about can finish writing proper documentation.

Collaborator:

Wrote a small section about raising, if you want to include it; or I can open a PR after this one is merged: 5fe6e01

@codecov (bot) commented Nov 29, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 64.92%. Comparing base (b39a1fc) to head (85af24b).
⚠️ Report is 209 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1920      +/-   ##
==========================================
- Coverage   68.16%   64.92%   -3.25%     
==========================================
  Files         109      121      +12     
  Lines       11779    13139    +1360     
==========================================
+ Hits         8029     8530     +501     
- Misses       3750     4609     +859     

☔ View full report in Codecov by Sentry.

@yolhan83 (Contributor) commented Nov 30, 2025

Hey all, indeed a doc about raising may be interesting, but a kernel part may be just as important. I think `RArray` automatically triggers `CUDABackend` on KA, which means that without raising it may only work on CUDA. A thing to test when not raising is setting the KA backend by hand, and the Reactant backend too, but I'm not sure Reactant will be able to send the Metal or AMD device code to the XLA backend anyway.
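
For reference, "setting the KA backend by hand" could look something like the following (purely illustrative, reusing `square_kernel!` from the sketch above; whether Reactant can consume the resulting device code is exactly the open question):

```julia
using KernelAbstractions

# instead of relying on get_backend(x) to pick the backend automatically,
# instantiate the kernel on an explicit backend:
backend = CPU()   # or CUDABackend(), MetalBackend(), ROCBackend() from the GPU packages
kernel! = square_kernel!(backend)
kernel!(y, x; ndrange=length(x))
```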

@@ -0,0 +1,73 @@
# Kernels

Collaborator:

Just a random thought, but maybe the tutorial could be called something like "Compute Kernels" or "GPU Kernels"? IMHO, "kernel" has a lot of different meanings in computing.

@gdalle (Contributor Author) commented Nov 30, 2025

So, as I understand it:

  • Reactant compilation should work without `raise=true`, but on my Mac it doesn't
  • Reactant compilation with `raise=true` should work without `using CUDA`, but on my Mac it doesn't

In light of this, should we wait for this to be fixed or should we document the behavior that we know for sure works across platforms?

@Pangoraw (Collaborator) commented Dec 2, 2025

> Reactant compilation should work without `raise=true`, but on my Mac it doesn't

#1923 would fix this for KA kernels; it should already work for CUDA.jl kernels.

@gdalle (Contributor Author) commented Dec 2, 2025

Thanks! From what I see, the fix lives in Reactant's CUDA extension, so it will remain necessary to run `using CUDA` even on platforms where it is not functional?

@Pangoraw (Collaborator) commented Dec 2, 2025

Yes, it uses CUDA.jl to extract the kernel IR, so it will need to be loaded (either by the user or by the KA Reactant extension).

> we should force load this for you

@wsmoses do you know how we can achieve that?

@gdalle (Contributor Author) commented Dec 2, 2025

By making CUDA a hard dep of Reactant? Why isn't that the case if you need it for kernels?

@gdalle (Contributor Author) commented Dec 2, 2025

I understand the need to keep it as a weak dep though, and I don't think Pkg manipulations should be performed under the hood. So the current error message might be good enough; I was just a bit surprised that installing CUDA on my Mac would actually solve the issue (which it did).

@Pangoraw (Collaborator) commented Dec 4, 2025

> By making CUDA a hard dep of Reactant? Why isn't that the case if you need it for kernels?

We need it for CUDA kernels and KA kernels, so it can be fair to require `using CUDA` even in the case of a non-CUDA setup.

