From 32765a2693e2b8f25bcd1f963172d1cc4f39b80f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 6 Mar 2025 14:51:13 +0000 Subject: [PATCH 1/3] Make Gibbs work with step_warmup --- src/mcmc/gibbs.jl | 110 ++++++++++++++++++++++++++++++++++--- src/mcmc/repeat_sampler.jl | 27 +++++++++ test/mcmc/gibbs.jl | 96 ++++++++++++++++++++++++++++++++ 3 files changed, 225 insertions(+), 8 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 5af01388e5..2e7e0478e7 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -405,20 +405,75 @@ end varinfo(state::GibbsState) = state.vi -function DynamicPPL.initialstep( +""" +Initialise a VarInfo for the Gibbs sampler. + +This is straight up copypasta from DynamicPPL's src/sampler.jl. It is repeated here to +support calling both step and step_warmup as the initial step. DynamicPPL initialstep is +incompatible with step_warmup. +""" +function initial_varinfo(rng, model, spl, initial_params) + vi = DynamicPPL.default_varinfo(rng, model, spl) + + # Update the parameters if provided. + if initial_params !== nothing + vi = DynamicPPL.initialize_parameters!!(vi, initial_params, model) + + # Update joint log probability. + # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 + # and https://github.com/TuringLang/Turing.jl/issues/1563 + # to avoid that existing variables are resampled + vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext())) + end + return vi +end + +function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - vi::DynamicPPL.AbstractVarInfo; + spl::DynamicPPL.Sampler{<:Gibbs}; initial_params=nothing, kwargs..., ) alg = spl.alg varnames = alg.varnames samplers = alg.samplers + vi = initial_varinfo(rng, model, spl, initial_params) vi, states = gibbs_initialstep_recursive( - rng, model, varnames, samplers, vi; initial_params=initial_params, kwargs... + rng, + model, + AbstractMCMC.step, + varnames, + samplers, + vi; + initial_params=initial_params, + kwargs..., + ) + return Transition(model, vi), GibbsState(vi, states) +end + +function AbstractMCMC.step_warmup( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}; + initial_params=nothing, + kwargs..., +) + alg = spl.alg + varnames = alg.varnames + samplers = alg.samplers + vi = initial_varinfo(rng, model, spl, initial_params) + + vi, states = gibbs_initialstep_recursive( + rng, + model, + AbstractMCMC.step_warmup, + varnames, + samplers, + vi; + initial_params=initial_params, + kwargs..., ) return Transition(model, vi), GibbsState(vi, states) end @@ -427,9 +482,20 @@ end Take the first step of MCMC for the first component sampler, and call the same function recursively on the remaining samplers, until no samplers remain. Return the global VarInfo and a tuple of initial states for all component samplers. + +The `step_function` argument should always be either AbstractMCMC.step or +AbstractMCMC.step_warmup. """ function gibbs_initialstep_recursive( - rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs... + rng, + model, + step_function::Function, + varname_vecs, + samplers, + vi, + states=(); + initial_params=nothing, + kwargs..., ) # End recursion if isempty(varname_vecs) && isempty(samplers) @@ -450,7 +516,7 @@ function gibbs_initialstep_recursive( conditioned_model, context = make_conditional(model, varnames, vi) # Take initial step with the current sampler. - _, new_state = AbstractMCMC.step( + _, new_state = step_function( rng, conditioned_model, sampler; @@ -470,6 +536,7 @@ function gibbs_initialstep_recursive( return gibbs_initialstep_recursive( rng, model, + step_function, varname_vecs_tail, samplers_tail, vi, @@ -493,7 +560,29 @@ function AbstractMCMC.step( states = state.states @assert length(samplers) == length(state.states) - vi, states = gibbs_step_recursive(rng, model, varnames, samplers, states, vi; kwargs...) + vi, states = gibbs_step_recursive( + rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs... + ) + return Transition(model, vi), GibbsState(vi, states) +end + +function AbstractMCMC.step_warmup( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + state::GibbsState; + kwargs..., +) + vi = varinfo(state) + alg = spl.alg + varnames = alg.varnames + samplers = alg.samplers + states = state.states + @assert length(samplers) == length(state.states) + + vi, states = gibbs_step_recursive( + rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs... + ) return Transition(model, vi), GibbsState(vi, states) end @@ -620,10 +709,14 @@ end """ Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same function on the tail, until there are no more samplers left. + +The `step_function` argument should always be either AbstractMCMC.step or +AbstractMCMC.step_warmup. """ function gibbs_step_recursive( rng::Random.AbstractRNG, model::DynamicPPL.Model, + step_function::Function, varname_vecs, samplers, states, @@ -657,7 +750,7 @@ function gibbs_step_recursive( state = setparams_varinfo!!(conditioned_model, sampler, state, vi) # Take a step with the local sampler. - new_state = last(AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...)) + new_state = last(step_function(rng, conditioned_model, sampler, state; kwargs...)) new_vi_local = varinfo(new_state) # Merge the latest values for all the variables in the current sampler. @@ -668,6 +761,7 @@ function gibbs_step_recursive( return gibbs_step_recursive( rng, model, + step_function, varname_vecs_tail, samplers_tail, states_tail, diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl index a3e38f46a9..775bbdce3a 100644 --- a/src/mcmc/repeat_sampler.jl +++ b/src/mcmc/repeat_sampler.jl @@ -60,3 +60,30 @@ function AbstractMCMC.step( end return transition, state end + +function AbstractMCMC.step_warmup( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatSampler; + kwargs..., +) + return AbstractMCMC.step_warmup(rng, model, sampler.sampler; kwargs...) +end + +function AbstractMCMC.step_warmup( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::RepeatSampler, + state; + kwargs..., +) + transition, state = AbstractMCMC.step_warmup( + rng, model, sampler.sampler, state; kwargs... + ) + for _ in 2:(sampler.num_repeat) + transition, state = AbstractMCMC.step_warmup( + rng, model, sampler.sampler, state; kwargs... + ) + end + return transition, state +end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 66fe51632f..697073d0a5 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -268,6 +268,102 @@ end @test chain1.value == chain2.value end +@testset "Gibbs warmup" begin + # An inference algorithm, for testing purposes, that records how many warm-up steps + # and how many non-warm-up steps haven been taken. + mutable struct WarmupCounter <: Inference.InferenceAlgorithm + warmup_init_count::Int + non_warmup_init_count::Int + warmup_count::Int + non_warmup_count::Int + + WarmupCounter() = new(0, 0, 0, 0) + end + + Turing.Inference.drop_space(wuc::WarmupCounter) = wuc + Turing.Inference.getspace(::WarmupCounter) = () + Turing.Inference.isgibbscomponent(::WarmupCounter) = true + + # A trivial state that holds nothing but a VarInfo, to be used with WarmupCounter. + struct VarInfoState{T} + vi::T + end + + Turing.Inference.varinfo(state::VarInfoState) = state.vi + function Turing.Inference.setparams_varinfo!!( + ::DynamicPPL.Model, + ::DynamicPPL.Sampler, + ::VarInfoState, + params::DynamicPPL.AbstractVarInfo, + ) + return VarInfoState(params) + end + + function AbstractMCMC.step( + ::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:WarmupCounter}; + kwargs..., + ) + spl.alg.non_warmup_init_count += 1 + return Turing.Inference.Transition(nothing, 0.0), + VarInfoState(DynamicPPL.VarInfo(model)) + end + + function AbstractMCMC.step_warmup( + ::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:WarmupCounter}; + kwargs..., + ) + spl.alg.warmup_init_count += 1 + return Turing.Inference.Transition(nothing, 0.0), + VarInfoState(DynamicPPL.VarInfo(model)) + end + + function AbstractMCMC.step( + ::Random.AbstractRNG, + ::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:WarmupCounter}, + s::VarInfoState; + kwargs..., + ) + spl.alg.non_warmup_count += 1 + return Turing.Inference.Transition(nothing, 0.0), s + end + + function AbstractMCMC.step_warmup( + ::Random.AbstractRNG, + ::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:WarmupCounter}, + s::VarInfoState; + kwargs..., + ) + spl.alg.warmup_count += 1 + return Turing.Inference.Transition(nothing, 0.0), s + end + + @model f() = x ~ Normal() + m = f() + + num_samples = 10 + num_warmup = 3 + wuc = WarmupCounter() + sample(m, Gibbs(:x => wuc), num_samples; num_warmup=num_warmup) + @test wuc.warmup_init_count == 1 + @test wuc.non_warmup_init_count == 0 + @test wuc.warmup_count == num_warmup + @test wuc.non_warmup_count == num_samples - 1 + + num_reps = 2 + wuc = WarmupCounter() + sample(m, Gibbs(:x => RepeatSampler(wuc, num_reps)), num_samples; num_warmup=num_warmup) + @test wuc.warmup_init_count == 1 + @test wuc.non_warmup_init_count == 0 + @test wuc.warmup_count == num_warmup * num_reps + @test wuc.non_warmup_count == (num_samples - 1) * num_reps +end + @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @info "Starting Gibbs tests with $adbackend" @testset "Deprecated Gibbs constructors" begin From af44a546ae63cc52189c56b7a19f561c3a46c2d9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 6 Mar 2025 14:51:34 +0000 Subject: [PATCH 2/3] Bump patch version to 0.36.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c3285fc41b..459f11dcbd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.36.2" +version = "0.36.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From e2805a5049faf249777eee18833853a902897fb1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 6 Mar 2025 17:23:42 +0000 Subject: [PATCH 3/3] Fix a Gibbs bug --- src/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 2e7e0478e7..45e6f93262 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -417,7 +417,7 @@ function initial_varinfo(rng, model, spl, initial_params) # Update the parameters if provided. if initial_params !== nothing - vi = DynamicPPL.initialize_parameters!!(vi, initial_params, model) + vi = DynamicPPL.initialize_parameters!!(vi, initial_params, spl, model) # Update joint log probability. # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588