Skip to content

Commit 96df1a4

Browse files
committed
Include more options
1 parent a7b22e4 commit 96df1a4

File tree

4 files changed

+152
-40
lines changed

4 files changed

+152
-40
lines changed

docs/src/api.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,3 +505,11 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
505505
DynamicPPL.Experimental.determine_suitable_varinfo
506506
DynamicPPL.Experimental.is_suitable_varinfo
507507
```
508+
509+
### Converting VarInfo to chains
510+
511+
The following function is useful for package developers seeking to extend DynamicPPL:
512+
513+
```@docs
514+
DynamicPPL.varinfos_to_chains
515+
```

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,30 +36,51 @@ function chain_sample_to_varname_dict(
3636
return d
3737
end
3838

39+
struct ModelEvaluationData{
40+
Tval<:AbstractDict,Tpr<:Union{AbstractFloat,Missing},Tlk<:Union{AbstractFloat,Missing}
41+
}
42+
values::Tval
43+
logprior::Tpr
44+
loglikelihood::Tlk
45+
end
46+
3947
"""
4048
DynamicPPL.varinfos_to_chains(
4149
::Type{MCMCChains.Chains},
42-
model::Model,
4350
varinfos::AbstractArray{<:AbstractVarInfo},
51+
model::Union{DynamicPPL.Model,Nothing};
4452
include_colon_eq::Bool=true,
53+
include_log_probs::Bool=true,
4554
)
4655
4756
Convert an array of `VarInfo`s to an `MCMCChains.Chains` object.
4857
"""
4958
function DynamicPPL.varinfos_to_chains(
5059
::Type{MCMCChains.Chains},
51-
model::DynamicPPL.Model,
5260
varinfos::AbstractMatrix{<:DynamicPPL.AbstractVarInfo},
61+
model::Union{DynamicPPL.Model,Nothing};
5362
include_colon_eq::Bool=true,
63+
include_log_probs::Bool=true,
5464
)
5565
all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
56-
# Re-evaluate model
57-
param_dicts = map(varinfos) do vi
58-
# Dict{VarName, Any}
59-
full_vals = DynamicPPL.values_as_in_model(model, include_colon_eq, vi)
66+
evaluation_data = map(varinfos) do vi
67+
# Re-evaluate model if needed
68+
new_vi = if model isa DynamicPPL.Model
69+
accs = (
70+
DynamicPPL.LogPriorAccumulator(),
71+
DynamicPPL.LogJacobianAccumulator(),
72+
DynamicPPL.LogLikelihoodAccumulator(),
73+
DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),
74+
)
75+
vi = DynamicPPL.setaccs!!(vi, accs)
76+
last(DynamicPPL.evaluate!!(model, vi))
77+
else
78+
vi
79+
end
80+
full_vals = DynamicPPL.getacc(new_vi, Val(:ValuesAsInModel)).values
6081
# Separate into individual VarNames.
6182
vn_leaves_and_vals = if isempty(full_vals)
62-
Tuple{VarName,Any}[]
83+
Tuple{DynamicPPL.VarName,Any}[]
6384
else
6485
iters = map(
6586
AbstractPPL.varname_and_value_leaves,
@@ -73,29 +94,58 @@ function DynamicPPL.varinfos_to_chains(
7394
for vn_leaf in vn_leaves
7495
push!(all_vn_leaves, vn_leaf)
7596
end
76-
return DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
97+
vn_leaves_dict = DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
98+
logprior = if include_log_probs && DynamicPPL.hasacc(new_vi, Val(:LogPrior))
99+
DynamicPPL.getacc(new_vi, Val(:LogPrior)).logp
100+
else
101+
missing
102+
end
103+
loglikelihood =
104+
if include_log_probs && DynamicPPL.hasacc(new_vi, Val(:LogLikelihood))
105+
DynamicPPL.getacc(new_vi, Val(:LogLikelihood)).logp
106+
else
107+
missing
108+
end
109+
return ModelEvaluationData(vn_leaves_dict, logprior, loglikelihood)
77110
end
78111
vn_leaves = collect(all_vn_leaves)
79112
vals = [
80-
get(param_dicts[i, j], key, missing) for i in eachindex(axes(param_dicts, 1)),
81-
key in vn_leaves, j in eachindex(axes(param_dicts, 2))
113+
get(evaluation_data[i, j].values, key, missing) for
114+
i in eachindex(axes(evaluation_data, 1)), key in vn_leaves,
115+
j in eachindex(axes(evaluation_data, 2))
82116
]
83117
symbols = map(Symbol, vn_leaves)
84118
info = (
85119
varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
86120
zip(all_vn_leaves, symbols)
87121
),
88122
)
89-
return MCMCChains.Chains(MCMCChains.concretize(vals), symbols; info=info)
123+
name_map = NamedTuple()
124+
if include_log_probs
125+
logpriors = map(e -> e.logprior, evaluation_data)
126+
loglikelihoods = map(e -> e.loglikelihood, evaluation_data)
127+
logjoints = map(e -> e.logprior + e.loglikelihood, evaluation_data)
128+
lps = permutedims(stack([logpriors, loglikelihoods, logjoints]), (1, 3, 2))
129+
vals = hcat(vals, lps)
130+
lp_symbols = [:logprior, :loglikelihood, :lp]
131+
append!(symbols, lp_symbols)
132+
name_map = (internals=lp_symbols,)
133+
end
134+
return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info)
90135
end
91136
function DynamicPPL.varinfos_to_chains(
92137
::Type{MCMCChains.Chains},
93-
model::DynamicPPL.Model,
94138
varinfos::AbstractVector{<:DynamicPPL.AbstractVarInfo},
139+
model::Union{DynamicPPL.Model,Nothing};
95140
include_colon_eq::Bool=true,
141+
include_log_probs::Bool=true,
96142
)
97143
return DynamicPPL.varinfos_to_chains(
98-
MCMCChains.Chains, model, hcat(varinfos), include_colon_eq
144+
MCMCChains.Chains,
145+
hcat(varinfos),
146+
model;
147+
include_colon_eq=include_colon_eq,
148+
include_log_probs=include_log_probs,
99149
)
100150
end
101151

src/abstract_varinfo.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,19 +1186,34 @@ end
11861186
"""
11871187
varinfos_to_chains(
11881188
Tout::Type{<:AbstractChains},
1189-
model::DynamicPPL.Model,
11901189
varinfos::AbstractArray{<:AbstractVarInfo},
1191-
include_colon_eq::Bool=true
1190+
model::Union{DynamicPPL.Model,Nothing};
1191+
include_colon_eq::Bool=true,
1192+
include_log_probs::Bool=true,
11921193
)
11931194
11941195
Convert an array of `varinfos` to a chains object of type `Tout`.
11951196
1196-
The `model` is required in order to account for cases where the varinfo is linked and
1197-
re-evaluation is required. For example, this is the case when the support of a distribution
1198-
depends on other random variables.
1197+
In many cases, it is necessary to re-evaluate the model with each VarInfo to obtain correct
1198+
parameter values as well as log probabilities. This is especially true when using linked
1199+
VarInfos, i.e., when variables have been transformed to unconstrained space, and if this is
1200+
not done, subtle correctness bugs may arise: see, e.g.,
1201+
https://github.com/TuringLang/Turing.jl/issues/2195
1202+
1203+
There is one case where re-evaluation is not necessary, which is when the VarInfos all
1204+
already contain `DynamicPPL.ValuesAsInModelAccumulator`. This accumulator stores values
1205+
as seen during the model evaluation, so the values can be simply read off. In this case,
1206+
`model` can be set to `nothing`, and no re-evaluation will be performed. However, it is the
1207+
caller's responsibility to ensure that `ValuesAsInModelAccumulator` is indeed
1208+
present.
11991209
12001210
`include_colon_eq` indicates whether to include variables on the left-hand side of `:=`.
12011211
1212+
`include_log_probs` indicates whether to include log probabilities (log prior, log
1213+
likelihood, and log joint) in the resulting chains object. By default these are included
1214+
(and also recalculated if a model is provided; this incurs almost no additional cost on top
1215+
of the original re-evaluation).
1216+
12021217
This function is not implemented here but rather in package extensions for individual chains
12031218
packages.
12041219
"""

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,59 +15,98 @@ using LinearAlgebra: I
1515
@test mean(chain_generated) 0 atol = 0.1
1616

1717
@testset "varinfos_to_chains" begin
18-
@model function f()
18+
@model function f(z)
1919
x ~ Normal()
2020
y := x + 1
21-
return z ~ MvNormal(zeros(3), I)
21+
return z ~ Normal(y)
2222
end
2323

24-
model = f()
24+
z = 1.0
25+
model = f(z)
2526

2627
@testset "vector" begin
2728
vis = [VarInfo(model) for _ in 1:50]
28-
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, vis)
29+
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, vis, model)
2930
@test c isa MCMCChains.Chains
30-
@test size(c) == (50, 5, 1)
31+
@test size(c, 1) == 50
32+
@test size(c, 3) == 1
33+
@test Set(c.name_map.parameters) == Set([:x, :y])
34+
@test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp])
35+
@test logpdf.(Normal(), c[:x]) c[:logprior]
36+
@test c.info.varname_to_symbol[@varname(x)] == :x
37+
@test c.info.varname_to_symbol[@varname(y)] == :y
38+
end
39+
40+
@testset "vector, no reevaluation" begin
41+
# check that it throws an error without VAIMAcc
42+
vis = [VarInfo(model) for _ in 1:50]
43+
@test_throws ErrorException DynamicPPL.varinfos_to_chains(
44+
MCMCChains.Chains, vis, nothing
45+
)
46+
# and that it works with VAIMAcc
47+
vi = DynamicPPL.VarInfo(model)
48+
vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.ValuesAsInModelAccumulator(true),))
49+
vis = [last(DynamicPPL.init!!(model, vi)) for _ in 1:50]
50+
c = DynamicPPL.varinfos_to_chains(
51+
MCMCChains.Chains, vis, nothing; include_log_probs=false
52+
)
53+
@test c isa MCMCChains.Chains
54+
@test size(c, 1) == 50
55+
@test size(c, 3) == 1
56+
@test Set(c.name_map.parameters) == Set([:x, :y])
3157
@test c.info.varname_to_symbol[@varname(x)] == :x
3258
@test c.info.varname_to_symbol[@varname(y)] == :y
33-
@test c.info.varname_to_symbol[@varname(z[1])] == Symbol("z[1]")
34-
@test c.info.varname_to_symbol[@varname(z[2])] == Symbol("z[2]")
35-
@test c.info.varname_to_symbol[@varname(z[3])] == Symbol("z[3]")
3659
end
3760

3861
@testset "vector, without include_colon_eq" begin
3962
vis = [VarInfo(model) for _ in 1:50]
40-
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, vis, false)
63+
c = DynamicPPL.varinfos_to_chains(
64+
MCMCChains.Chains, vis, model; include_colon_eq=false
65+
)
4166
@test c isa MCMCChains.Chains
42-
@test size(c) == (50, 4, 1)
67+
@test size(c, 1) == 50
68+
@test size(c, 3) == 1
69+
@test Set(c.name_map.parameters) == Set([:x])
70+
@test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp])
71+
@test logpdf.(Normal(), c[:x]) c[:logprior]
4372
@test c.info.varname_to_symbol[@varname(x)] == :x
44-
@test c.info.varname_to_symbol[@varname(z[1])] == Symbol("z[1]")
45-
@test c.info.varname_to_symbol[@varname(z[2])] == Symbol("z[2]")
46-
@test c.info.varname_to_symbol[@varname(z[3])] == Symbol("z[3]")
73+
end
74+
75+
@testset "vector, without logprobs" begin
76+
vis = [VarInfo(model) for _ in 1:50]
77+
c = DynamicPPL.varinfos_to_chains(
78+
MCMCChains.Chains, vis, model; include_log_probs=false
79+
)
80+
@test c isa MCMCChains.Chains
81+
@test size(c, 1) == 50
82+
@test size(c, 3) == 1
83+
@test Set(c.name_map.parameters) == Set([:x, :y])
84+
@test !haskey(c.name_map, :internals)
4785
end
4886

4987
@testset "Different VarInfo type" begin
5088
vis = [SimpleVarInfo(model) for _ in 1:50]
51-
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, vis)
89+
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, vis, model)
5290
@test c isa MCMCChains.Chains
53-
@test size(c) == (50, 5, 1)
91+
@test size(c, 1) == 50
92+
@test size(c, 3) == 1
93+
@test Set(c.name_map.parameters) == Set([:x, :y])
94+
@test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp])
95+
@test logpdf.(Normal(), c[:x]) c[:logprior]
5496
@test c.info.varname_to_symbol[@varname(x)] == :x
5597
@test c.info.varname_to_symbol[@varname(y)] == :y
56-
@test c.info.varname_to_symbol[@varname(z[1])] == Symbol("z[1]")
57-
@test c.info.varname_to_symbol[@varname(z[2])] == Symbol("z[2]")
58-
@test c.info.varname_to_symbol[@varname(z[3])] == Symbol("z[3]")
5998
end
6099

61100
@testset "matrix" begin
62101
vis = [VarInfo(model) for _ in 1:50, _ in 1:3]
63-
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, vis)
102+
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, vis, model)
64103
@test c isa MCMCChains.Chains
65104
@test size(c) == (50, 5, 3)
105+
@test Set(c.name_map.parameters) == Set([:x, :y])
106+
@test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp])
107+
@test logpdf.(Normal(), c[:x]) c[:logprior]
66108
@test c.info.varname_to_symbol[@varname(x)] == :x
67109
@test c.info.varname_to_symbol[@varname(y)] == :y
68-
@test c.info.varname_to_symbol[@varname(z[1])] == Symbol("z[1]")
69-
@test c.info.varname_to_symbol[@varname(z[2])] == Symbol("z[2]")
70-
@test c.info.varname_to_symbol[@varname(z[3])] == Symbol("z[3]")
71110
end
72111
end
73112
end

0 commit comments

Comments
 (0)