-
-
Notifications
You must be signed in to change notification settings - Fork 615
Closed
Labels
Description
using Flux, Enzyme, Statistics, Random
"""
    enzyme_withgradient(f, x...)

Differentiate `f(x...)` with Enzyme in reverse mode and return `(value, grads)`.

Each argument is annotated according to its type:
- `Number` arguments are passed as `Enzyme.Active`.
- All other arguments are passed as `Enzyme.Duplicated` with a zeroed shadow (`Enzyme.make_zero`).

`f` itself is passed as `Enzyme.Const`, and runtime activity is enabled on the
`ReverseWithPrimal` mode.

# Returns
- `value`: the primal result of `f(x...)`.
- `grads`: a tuple with one entry per argument. For `Active` arguments it is the
  derivative returned by `autodiff`; for `Duplicated` arguments it is the accumulated
  shadow.

The result of `f` is annotated `Enzyme.Active`, so `f` is expected to return a scalar.
"""
function enzyme_withgradient(f, x...)
    # Mapping over the argument tuple yields a concretely typed tuple of annotations
    # (instead of an untyped `[]`), so the splat into `autodiff` is statically dispatched.
    args = map(_enzyme_annotate, x)
    mode = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    # With `ReverseWithPrimal`, `autodiff` returns (per-argument derivatives, primal).
    derivs, val = Enzyme.autodiff(mode, Enzyme.Const(f), Enzyme.Active, args...)
    grads = map(_enzyme_gradient, args, derivs)
    return val, grads
end

# Wrap a numeric argument as `Enzyme.Active`; its derivative is returned by `autodiff`.
_enzyme_annotate(a::Number) = Enzyme.Active(a)
# Wrap any other argument as `Enzyme.Duplicated` with a zero shadow that accumulates
# its gradient.
_enzyme_annotate(a) = Enzyme.Duplicated(a, Enzyme.make_zero(a))

# Gradient of an `Active` argument: the derivative returned by `autodiff`.
_enzyme_gradient(::Enzyme.Active, d) = d
# Gradient of a `Duplicated` argument: its accumulated shadow.
_enzyme_gradient(a::Enzyme.Duplicated, _) = a.dval
# Scalar objective: the arithmetic mean of all elements that `model` produces for `x`.
function loss(model, x)
    y = model(x)
    return mean(y)
end
# Minimal reproducer: take the Enzyme gradient of `loss` through a `MeanPool` layer.
# A fixed-seed RNG (using the already-imported `Random`) makes the input reproducible.
rng = Random.MersenneTwister(0)
model = MeanPool((3, 3))
x = rand(rng, Float32, 3, 3, 2, 2)
enzyme_withgradient(loss, model, x)
Output:
ERROR:
No create nofree of empty function (julia.gc_loaded) julia.gc_loaded)
at context: call fastcc void @julia__PoolDims_14_89677({ [2 x i64], [2 x i64], i64, [2 x i64], [4 x i64], [2 x i64] }* noalias nocapture nofree noundef nonnull writeonly sret({ [2 x i64], [2 x i64], i64, [2 x i64], [4 x i64], [2 x i64] }) align 8 dereferenceable(104) %3, [2 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %12, [4 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(48) %11, [4 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(32) %20, [2 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(64) %13) #49, !dbg !75 (julia__PoolDims_14_89677)
Stacktrace:
[1] PoolDims
@ ~/.julia/packages/NNlib/mRRJu/src/dim_helpers/PoolDims.jl:20
[2] PoolDims
@ ~/.julia/packages/NNlib/mRRJu/src/dim_helpers/PoolDims.jl:43
[3] MeanPool
@ ~/.julia/dev/Flux/src/layers/conv.jl:774
[4] loss
@ ~/.julia/dev/Flux/prova.jl:14
[5] loss
@ ~/.julia/dev/Flux/prova.jl:0
Stacktrace:
[1] PoolDims
@ ~/.julia/packages/NNlib/mRRJu/src/dim_helpers/PoolDims.jl:20 [inlined]
[2] PoolDims
@ ~/.julia/packages/NNlib/mRRJu/src/dim_helpers/PoolDims.jl:43 [inlined]
[3] MeanPool
@ ~/.julia/dev/Flux/src/layers/conv.jl:774 [inlined]
[4] loss
@ ~/.julia/dev/Flux/prova.jl:14 [inlined]
[5] loss
@ ~/.julia/dev/Flux/prova.jl:0 [inlined]
[6] diffejulia_loss_89542_inner_6wrap
@ ~/.julia/dev/Flux/prova.jl:0
[7] macro expansion
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5317 [inlined]
[8] enzyme_call
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:4863 [inlined]
[9] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:4735 [inlined]
[10] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
@ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:503
[11] enzyme_withgradient(::Function, ::MeanPool{2, 4}, ::Vararg{Any})
@ Main ~/.julia/dev/Flux/test/test_utils.jl:32
[12] top-level scope
@ ~/.julia/dev/Flux/prova.jl:17
Some type information was truncated. Use `show(err)` to see complete types.
cc @wsmoses