diff --git a/src/Enzyme.jl b/src/Enzyme.jl index e263e4d716..6be770720f 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -69,6 +69,20 @@ end return Base.zero(x) end +@inline function Enzyme.tupstack( + data::Tuple{<:RArray, Vararg{<:RArray}}, + outshape::Tuple{Vararg{Int}}, + inshape::Tuple{Vararg{Int}}, +) + res = similar(first(data), outshape..., inshape...) + c = CartesianIndices(outshape) + tail_dims = map(Returns(:), inshape) + for (i, val) in enumerate(data) + @inbounds res[c[i], tail_dims...] = val + end + return res +end + macro register_make_zero_inplace(sym) quote @inline function $sym(prev::RArray{T,N})::Nothing where {T<:AbstractFloat,N} diff --git a/test/autodiff.jl b/test/autodiff.jl index 10daa32ee8..5fd0188c0d 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -102,6 +102,19 @@ end @test res[1] ≈ ones(2, 2) end +f_sincos(x) = map(sin, x) + map(cos, reverse(x)) +f_sincos_jac(x) = [cos(x[i]) * (i == j ? 1 : 0) - sin(x[end-i+1]) * (i == (length(x) - j + 1) ? 1 : 0) for i in 1:length(x), j in 1:length(x)] + +@testset "Forward Jacobian" begin + jac(x) = only(Enzyme.jacobian(Enzyme.Forward, f_sincos, x)) + x_r = Reactant.to_rarray(rand(10)) + + j_gt = Reactant.@allowscalar f_sincos_jac(x_r) + j_reactant = Reactant.@jit jac(x_r) + + @test j_reactant ≈ j_gt +end + mutable struct StateReturn st::Any end