Skip to content

Custom training loop implementation help #19

@natalieklein229

Description

@natalieklein229

Hi, I am trying to set up a custom training loop following the code at the bottom of the README, but I have not been able to make it work. I am relatively new to Julia, so I am probably not the best at debugging this stuff right now. Any pointers are appreciated!

Here is the code:

using Flux
using Turing, AdvancedVI, Distributions, DynamicPPL, StatsFuns, DiffResults
using Turing: Variational
using StatsBase

"""
    vi_custom(model, q_init=nothing; n_mc, n_iter, tol, optimizer)

Run a custom ADVI training loop for a Turing `model`.

# Arguments
- `model`: a `DynamicPPL.Model`.
- `q_init`: optional previously-fitted mean-field distribution; when `nothing`
  the variational parameters are initialised randomly.

# Keywords
- `n_mc`: number of Monte Carlo samples for the ELBO estimate.
- `n_iter`: maximum number of optimisation steps.
- `tol`: convergence tolerance on the Euclidean parameter update norm.
- `optimizer`: a Flux optimiser, e.g. `Flux.ADAM()`.

# Returns
`(θ, step, q)`: final unconstrained parameters `[μ; invsoftplus.(σ)]`,
the number of steps taken, and the mean-field variational family.
"""
function vi_custom(model, q_init=nothing; n_mc, n_iter, tol, optimizer)
    varinfo = DynamicPPL.VarInfo(model)
    # Total number of latent scalar entries across all model symbols.
    # (The original had a paste error: the `in` between `sym` and `keys` was lost.)
    num_params = sum(size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata))
    variational_objective = Variational.ELBO()
    alg = ADVI(n_mc, n_iter)
    # Set up the unconstrained variational parameters θ = [μ; invsoftplus.(σ)].
    if isnothing(q_init)
        μ = randn(num_params)
        σ = StatsFuns.softplus.(randn(num_params))
    else
        μ, σs = StatsBase.params(q_init)
        σ = StatsFuns.invsoftplus.(σs)
    end
    θ = vcat(μ, σ)
    q = Variational.meanfield(model)
    converged = false
    step = 1
    diff_result = DiffResults.GradientResult(θ)
    while (step <= n_iter) && !converged
        # 1. Compute gradient and objective value; results are stored in `diff_result`.
        AdvancedVI.grad!(variational_objective, alg, q, model, θ, diff_result)
        # 2. Extract gradient from `diff_result`.
        #    (This line was missing in the original — `∇` was undefined below.)
        ∇ = DiffResults.gradient(diff_result)
        # 3. Apply the optimizer, e.g. multiplying by the step size.
        #    `apply!` is not exported by Flux, so it must be qualified.
        Δ = Flux.Optimise.apply!(optimizer, θ, ∇)
        # 4. Update parameters in place.
        θ_prev = copy(θ)
        @. θ = θ - Δ
        # Converged when the update moved the parameters by less than `tol`.
        converged = sqrt(sum((θ - θ_prev).^2)) < tol
        step += 1
    end
    return θ, step, q
end

# Conjugate-style Normal model: InverseGamma prior on the variance `s`,
# Normal prior on the mean `μ`, and an elementwise Normal likelihood on `z`.
@model function norm(z)
    s ~ InverseGamma(1, 1)
    μ ~ Normal(0, sqrt(s))
    # likelihood
    z .~ Normal(μ, sqrt(s))
end

# Simulate 200 observations from Normal(mean = 1, sd = 2).
z = rand(Normal(1.0, 2.0), (200, 1));

# Fit with the custom loop: 25 MC samples per ELBO estimate, up to 20000 steps.
θ, step, q = vi_custom(norm(z); n_mc=25, n_iter=20000, tol=0.01, optimizer=Flux.ADAM())

And here is the stacktrace and error I get:

ERROR: LoadError: MethodError: no method matching (::ELBO)(::ADVI{AdvancedVI.ForwardDiffAD{40}}, ::Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1},Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate}, ::Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}})
Closest candidates are:
  Any(::Any, ::Any, ::Any, ::Any; kwargs...) at /.julia/packages/AdvancedVI/PaSeO/src/objectives.jl:5
  Any(::AbstractRNG, ::VariationalInference, ::Any, ::Model, ::Any; weight, kwargs...) at /.julia/packages/Turing/3goIa/src/variational/VariationalInference.jl:57
Stacktrace:
 [1] (::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}})(::Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}) at /.julia/packages/AdvancedVI/PaSeO/src/AdvancedVI.jl:140
 [2] vector_mode_dual_eval(::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4,Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}}) at /.julia/packages/ForwardDiff/sdToQ/src/apiutils.jl:37
 [3] vector_mode_gradient!(::DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}, ::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4,Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}}) at /.julia/packages/ForwardDiff/sdToQ/src/gradient.jl:103
 [4] gradient!(::DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}, ::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4,Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}}, ::Val{true}) at /.julia/packages/ForwardDiff/sdToQ/src/gradient.jl:35
 [5] gradient!(::DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}, ::AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4,Array{ForwardDiff.Dual{ForwardDiff.Tag{AdvancedVI.var"#f#19"{ELBO,ADVI{AdvancedVI.ForwardDiffAD{40}},Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate},Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}},Tuple{}},Float64},Float64,4},1}}) at /.julia/packages/ForwardDiff/sdToQ/src/gradient.jl:33
 [6] grad!(::ELBO, ::ADVI{AdvancedVI.ForwardDiffAD{40}}, ::Bijectors.TransformedDistribution{DistributionsAD.TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}},Stacked{Tuple{Bijectors.Exp{0},Identity{0}},2},Multivariate}, ::Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}}, ::Array{Float64,1}, ::DiffResults.MutableDiffResult{1,Float64,Tuple{Array{Float64,1}}}) at /.julia/packages/AdvancedVI/PaSeO/src/AdvancedVI.jl:149
 [7] vi_custom(::Model{var"#33#34",(:z,),(),(),Tuple{Array{Float64,2}},Tuple{}}, ::Nothing; n_mc::Int64, n_iter::Int64, tol::Float64, optimizer::ADAM) at custom_training_loop.jl:27
 [8] top-level scope at custom_training_loop.jl:53
 [9] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1088
in expression starting at custom_training_loop.jl:53

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions