Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.36.2"
version = "0.36.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
110 changes: 102 additions & 8 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,20 +405,75 @@

varinfo(state::GibbsState) = state.vi

function DynamicPPL.initialstep(
"""
Initialise a VarInfo for the Gibbs sampler.

This is straight up copypasta from DynamicPPL's src/sampler.jl. It is repeated here to
support calling both step and step_warmup as the initial step. DynamicPPL initialstep is
incompatible with step_warmup.
"""
function initial_varinfo(rng, model, spl, initial_params)
vi = DynamicPPL.default_varinfo(rng, model, spl)

# Update the parameters if provided.
if initial_params !== nothing
vi = DynamicPPL.initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
# and https://github.com/TuringLang/Turing.jl/issues/1563
# to avoid that existing variables are resampled
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext()))
end
return vi
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:Gibbs},
vi::DynamicPPL.AbstractVarInfo;
spl::DynamicPPL.Sampler{<:Gibbs};
initial_params=nothing,
kwargs...,
)
alg = spl.alg
varnames = alg.varnames
samplers = alg.samplers
vi = initial_varinfo(rng, model, spl, initial_params)

vi, states = gibbs_initialstep_recursive(
rng, model, varnames, samplers, vi; initial_params=initial_params, kwargs...
rng,
model,
AbstractMCMC.step,
varnames,
samplers,
vi;
initial_params=initial_params,
kwargs...,
)
return Transition(model, vi), GibbsState(vi, states)
end

function AbstractMCMC.step_warmup(

Check warning on line 456 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L456

Added line #L456 was not covered by tests
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:Gibbs};
initial_params=nothing,
kwargs...,
)
alg = spl.alg
varnames = alg.varnames
samplers = alg.samplers
vi = initial_varinfo(rng, model, spl, initial_params)

Check warning on line 466 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L463-L466

Added lines #L463 - L466 were not covered by tests

vi, states = gibbs_initialstep_recursive(

Check warning on line 468 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L468

Added line #L468 was not covered by tests
rng,
model,
AbstractMCMC.step_warmup,
varnames,
samplers,
vi;
initial_params=initial_params,
kwargs...,
)
return Transition(model, vi), GibbsState(vi, states)
end
Expand All @@ -427,9 +482,20 @@
Take the first step of MCMC for the first component sampler, and call the same function
recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
and a tuple of initial states for all component samplers.

The `step_function` argument should always be either AbstractMCMC.step or
AbstractMCMC.step_warmup.
"""
function gibbs_initialstep_recursive(
rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
rng,
model,
step_function::Function,
varname_vecs,
samplers,
vi,
states=();
initial_params=nothing,
kwargs...,
)
# End recursion
if isempty(varname_vecs) && isempty(samplers)
Expand All @@ -450,7 +516,7 @@
conditioned_model, context = make_conditional(model, varnames, vi)

# Take initial step with the current sampler.
_, new_state = AbstractMCMC.step(
_, new_state = step_function(
rng,
conditioned_model,
sampler;
Expand All @@ -470,6 +536,7 @@
return gibbs_initialstep_recursive(
rng,
model,
step_function,
varname_vecs_tail,
samplers_tail,
vi,
Expand All @@ -493,7 +560,29 @@
states = state.states
@assert length(samplers) == length(state.states)

vi, states = gibbs_step_recursive(rng, model, varnames, samplers, states, vi; kwargs...)
vi, states = gibbs_step_recursive(
rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs...
)
return Transition(model, vi), GibbsState(vi, states)
end

function AbstractMCMC.step_warmup(

Check warning on line 569 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L569

Added line #L569 was not covered by tests
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:Gibbs},
state::GibbsState;
kwargs...,
)
vi = varinfo(state)
alg = spl.alg
varnames = alg.varnames
samplers = alg.samplers
states = state.states
@assert length(samplers) == length(state.states)

Check warning on line 581 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L576-L581

Added lines #L576 - L581 were not covered by tests

vi, states = gibbs_step_recursive(

Check warning on line 583 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L583

Added line #L583 was not covered by tests
rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs...
)
return Transition(model, vi), GibbsState(vi, states)
end

Expand Down Expand Up @@ -620,10 +709,14 @@
"""
Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
function on the tail, until there are no more samplers left.

The `step_function` argument should always be either AbstractMCMC.step or
AbstractMCMC.step_warmup.
"""
function gibbs_step_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
step_function::Function,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this create some kind of type instability? my first order thought is that Function is abstract type

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be fine, since each function has its own type. E.g.

julia> typeof(identity)
typeof(identity) (singleton type of function identity, subtype of Function)

julia> typeof(sin)
typeof(sin) (singleton type of function sin, subtype of Function)

The ::Function bit just enforces that you can't pass as the step_function argument anything that isn't of type Function, but the compiler will still see the concrete type of the argument. I also checked that making this change didn't have a substantial impact on runtime.

varname_vecs,
samplers,
states,
Expand Down Expand Up @@ -657,7 +750,7 @@
state = setparams_varinfo!!(conditioned_model, sampler, state, vi)

# Take a step with the local sampler.
new_state = last(AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...))
new_state = last(step_function(rng, conditioned_model, sampler, state; kwargs...))

new_vi_local = varinfo(new_state)
# Merge the latest values for all the variables in the current sampler.
Expand All @@ -668,6 +761,7 @@
return gibbs_step_recursive(
rng,
model,
step_function,
varname_vecs_tail,
samplers_tail,
states_tail,
Expand Down
27 changes: 27 additions & 0 deletions src/mcmc/repeat_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,30 @@
end
return transition, state
end

function AbstractMCMC.step_warmup(

Check warning on line 64 in src/mcmc/repeat_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/repeat_sampler.jl#L64

Added line #L64 was not covered by tests
rng::Random.AbstractRNG,
model::AbstractMCMC.AbstractModel,
sampler::RepeatSampler;
kwargs...,
)
return AbstractMCMC.step_warmup(rng, model, sampler.sampler; kwargs...)

Check warning on line 70 in src/mcmc/repeat_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/repeat_sampler.jl#L70

Added line #L70 was not covered by tests
end

function AbstractMCMC.step_warmup(

Check warning on line 73 in src/mcmc/repeat_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/repeat_sampler.jl#L73

Added line #L73 was not covered by tests
rng::Random.AbstractRNG,
model::AbstractMCMC.AbstractModel,
sampler::RepeatSampler,
state;
kwargs...,
)
transition, state = AbstractMCMC.step_warmup(

Check warning on line 80 in src/mcmc/repeat_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/repeat_sampler.jl#L80

Added line #L80 was not covered by tests
rng, model, sampler.sampler, state; kwargs...
)
for _ in 2:(sampler.num_repeat)
transition, state = AbstractMCMC.step_warmup(

Check warning on line 84 in src/mcmc/repeat_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/repeat_sampler.jl#L83-L84

Added lines #L83 - L84 were not covered by tests
rng, model, sampler.sampler, state; kwargs...
)
end
return transition, state

Check warning on line 88 in src/mcmc/repeat_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/repeat_sampler.jl#L87-L88

Added lines #L87 - L88 were not covered by tests
end
96 changes: 96 additions & 0 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,102 @@ end
@test chain1.value == chain2.value
end

@testset "Gibbs warmup" begin
# An inference algorithm, for testing purposes, that records how many warm-up steps
# and how many non-warm-up steps haven been taken.
mutable struct WarmupCounter <: Inference.InferenceAlgorithm
warmup_init_count::Int
non_warmup_init_count::Int
warmup_count::Int
non_warmup_count::Int

WarmupCounter() = new(0, 0, 0, 0)
end

Turing.Inference.drop_space(wuc::WarmupCounter) = wuc
Turing.Inference.getspace(::WarmupCounter) = ()
Turing.Inference.isgibbscomponent(::WarmupCounter) = true

# A trivial state that holds nothing but a VarInfo, to be used with WarmupCounter.
struct VarInfoState{T}
vi::T
end

Turing.Inference.varinfo(state::VarInfoState) = state.vi
function Turing.Inference.setparams_varinfo!!(
::DynamicPPL.Model,
::DynamicPPL.Sampler,
::VarInfoState,
params::DynamicPPL.AbstractVarInfo,
)
return VarInfoState(params)
end

function AbstractMCMC.step(
::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:WarmupCounter};
kwargs...,
)
spl.alg.non_warmup_init_count += 1
return Turing.Inference.Transition(nothing, 0.0),
VarInfoState(DynamicPPL.VarInfo(model))
end

function AbstractMCMC.step_warmup(
::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:WarmupCounter};
kwargs...,
)
spl.alg.warmup_init_count += 1
return Turing.Inference.Transition(nothing, 0.0),
VarInfoState(DynamicPPL.VarInfo(model))
end

function AbstractMCMC.step(
::Random.AbstractRNG,
::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:WarmupCounter},
s::VarInfoState;
kwargs...,
)
spl.alg.non_warmup_count += 1
return Turing.Inference.Transition(nothing, 0.0), s
end

function AbstractMCMC.step_warmup(
::Random.AbstractRNG,
::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:WarmupCounter},
s::VarInfoState;
kwargs...,
)
spl.alg.warmup_count += 1
return Turing.Inference.Transition(nothing, 0.0), s
end

@model f() = x ~ Normal()
m = f()

num_samples = 10
num_warmup = 3
wuc = WarmupCounter()
sample(m, Gibbs(:x => wuc), num_samples; num_warmup=num_warmup)
@test wuc.warmup_init_count == 1
@test wuc.non_warmup_init_count == 0
@test wuc.warmup_count == num_warmup
@test wuc.non_warmup_count == num_samples - 1

num_reps = 2
wuc = WarmupCounter()
sample(m, Gibbs(:x => RepeatSampler(wuc, num_reps)), num_samples; num_warmup=num_warmup)
@test wuc.warmup_init_count == 1
@test wuc.non_warmup_init_count == 0
@test wuc.warmup_count == num_warmup * num_reps
@test wuc.non_warmup_count == (num_samples - 1) * num_reps
end

@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
@info "Starting Gibbs tests with $adbackend"
@testset "Deprecated Gibbs constructors" begin
Expand Down
Loading