-
Notifications
You must be signed in to change notification settings - Fork 19
Closed
Description
Hi, I am trying to set up a custom training loop following the code at the bottom of the README, but I have not been able to make it work. I am relatively new to Julia, so I am probably not the best at debugging this stuff right now. Any pointers are appreciated!
Here is the code:
using Flux
using Turing, AdvancedVI, Distributions, DynamicPPL, StatsFuns, DiffResults
using Turing: Variational
using StatsBase
"""
    vi_custom(model, q_init=nothing; n_mc, n_iter, tol, optimizer)

Fit a mean-field ADVI approximation to `model` using a hand-written optimisation loop.

# Arguments
- `model`: the model to approximate; it is passed to `DynamicPPL.VarInfo`,
  `Variational.meanfield` and `AdvancedVI.grad!`.
- `q_init`: optional initial variational distribution. When given,
  `StatsBase.params(q_init)` must return `(μ, σ)`; when `nothing`, the initial
  parameters are drawn at random.

# Keywords
- `n_mc`: number of Monte Carlo samples used for each ELBO gradient estimate.
- `n_iter`: maximum number of optimisation steps.
- `tol`: the loop stops once the Euclidean norm of a parameter update is below `tol`.
- `optimizer`: an optimiser accepted by `Flux.Optimise.apply!`, e.g. `Flux.ADAM()`.

# Returns
A tuple `(θ, step, q)`:
- `θ`: the final parameter vector, laid out as `vcat(μ, σ)`.
- `step`: the loop counter on exit, i.e. the number of steps performed plus one.
- `q`: the distribution built by `Variational.meanfield(model)`. It is NOT updated
  with `θ` inside this function.
"""
function vi_custom(model, q_init=nothing; n_mc, n_iter, tol, optimizer)
    varinfo = DynamicPPL.VarInfo(model)
    # Total number of scalar values stored across all variables of the model.
    num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym ∈ keys(varinfo.metadata)])

    variational_objective = Variational.ELBO()
    alg = ADVI(n_mc, n_iter)

    # Set up the initial variational parameters.
    if isnothing(q_init)
        μ = randn(num_params)
        # NOTE(review): this branch stores softplus-transformed values while the
        # branch below stores inverse-softplus values — confirm which scale the
        # variational family expects for the second half of θ.
        σ = StatsFuns.softplus.(randn(num_params))
    else
        μ, σs = StatsBase.params(q_init)
        σ = StatsFuns.invsoftplus.(σs)
    end
    θ = vcat(μ, σ)

    q = Variational.meanfield(model)

    converged = false
    step = 1
    diff_result = DiffResults.GradientResult(θ)

    while (step <= n_iter) && !converged
        # 1. Compute gradient and objective value; results are stored in `diff_result`.
        #    The trailing `n_mc` is forwarded by `grad!` to the objective as the
        #    number of samples. Without it the objective is called as
        #    `ELBO(alg, q, model)`, for which no method exists (MethodError).
        AdvancedVI.grad!(variational_objective, alg, q, model, θ, diff_result, n_mc)

        # 2. Extract the gradient from `diff_result`.
        ∇ = DiffResults.gradient(diff_result)

        # 3. Apply the optimizer, e.g. multiplying by the step-size. The call is
        #    qualified so it does not depend on `apply!` being exported into scope.
        Δ = Flux.Optimise.apply!(optimizer, θ, ∇)

        # 4. Update the parameters in place, keeping a copy for the convergence check.
        θ_prev = copy(θ)
        @. θ = θ - Δ

        # Converged when the Euclidean norm of the update drops below `tol`.
        converged = sqrt(sum(abs2, θ .- θ_prev)) < tol

        step += 1
    end

    return θ, step, q
end
# Model over the observations `z` with two latent variables: `s` and `μ`.
# `sqrt(s)` is passed as the second argument of `Normal` (the standard deviation in
# Distributions.jl), so `s` presumably plays the role of a variance.
@model norm(z) = begin
# prior on `s`
s ~ InverseGamma(1, 1)
# prior on `μ`, whose spread depends on `s`
μ ~ Normal(0, sqrt(s))
# likelihood: `.~` applies the same Normal to every element of `z`
z .~ Normal(μ, sqrt(s))
end
# Draw a 200×1 matrix of observations from Normal(1.0, 2.0).
z = rand(Normal(1.0, 2.0), 200, 1)

# Run the custom variational-inference loop on the model conditioned on `z`.
θ, step, q = vi_custom(
    norm(z);
    n_mc=25,
    n_iter=20000,
    tol=0.01,
    optimizer=Flux.ADAM(),
)
And here is the stacktrace and error I get:
ERROR: LoadError: MethodError: no method matching (::ELBO)(::ADVI{AdvancedVI.ForwardDiffAD{40}}, ::Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1},Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate}, ::Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}})
Closest candidates are:
Any(::Any, ::Any, ::Any, ::Any; kwargs...) at /.julia/packages/AdvancedVI/PaSeO/src/objectives.jl:5
Any(::AbstractRNG, ::VariationalInference, ::Any, ::Model, ::Any; weight, kwargs...) at /.julia/packages/Turing/3goIa/src/variational/VariationalInference.jl:57
Stacktrace:
[1] (::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}})(::Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}) at /.julia/packages/AdvancedVI/PaSeO/src/AdvancedVI.jl:140
[2] vector_mode_dual_eval(::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4,Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}}) at /.julia/packages/ForwardDiff/sdToQ/src/apiutils.jl:37
[3] vector_mode_gradient!(::DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}, ::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4,Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}}) at /.julia/packages/ForwardDiff/sdToQ/src/gradient.jl:103
[4] gradient!(::DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}, ::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4,Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}}, ::Val{true}) at /.julia/packages/ForwardDiff/sdToQ/src/gradient.jl:35
[5] gradient!(::DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}, ::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4,Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}}) at /.julia/packages/ForwardDiff/sdToQ/src/gradient.jl:33
[6] grad!(::ELBO, ::ADVI{AdvancedVI.ForwardDiffAD{40}}, ::Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate}, ::Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}}, ::Array{Float64,1}, ::DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}) at /.julia/packages/AdvancedVI/PaSeO/src/AdvancedVI.jl:149
[7] vi_custom(::Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}}, ::Nothing; n_mc::Int64, n_iter::Int64, tol::Float64, optimizer::ADAM) at custom_training_loop.jl:27
[8] top-level scope at custom_training_loop.jl:53
[9] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1088
in expression starting at custom_training_loop.jl:53
Metadata
Metadata
Assignees
Labels
No labels