Skip to content

Commit 833cbbf

Browse files
committed
Implement ParamsWithStats and to_chains functions
1 parent 1b159a6 commit 833cbbf

File tree

9 files changed

+319
-86
lines changed

9 files changed

+319
-86
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.38.2
4+
5+
Added a new exported struct, `DynamicPPL.ParamsWithStats`, and a corresponding function `DynamicPPL.to_chains`, which automatically converts a collection of `ParamsWithStats` to a given Chains type.
6+
37
## 0.38.1
48

59
Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.1"
3+
version = "0.38.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,3 +505,13 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
505505
DynamicPPL.Experimental.determine_suitable_varinfo
506506
DynamicPPL.Experimental.is_suitable_varinfo
507507
```
508+
509+
### Converting VarInfos to chains
510+
511+
It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis.
512+
This can be accomplished with the following:
513+
514+
```@docs
515+
DynamicPPL.ParamsWithStats
516+
DynamicPPL.to_chains
517+
```

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 72 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,76 @@ function chain_sample_to_varname_dict(
3636
return d
3737
end
3838

39+
"""
40+
DynamicPPL.to_chains(
41+
::Type{MCMCChains.Chains},
42+
params_and_stats::AbstractArray{<:ParamsWithStats}
43+
)
44+
45+
Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
46+
"""
47+
function DynamicPPL.to_chains(
48+
::Type{MCMCChains.Chains},
49+
params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
50+
)
51+
# Handle parameters
52+
all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
53+
split_dicts = map(params_and_stats) do ps
54+
# Separate into individual VarNames.
55+
vn_leaves_and_vals = if isempty(ps.params)
56+
Tuple{DynamicPPL.VarName,Any}[]
57+
else
58+
iters = map(
59+
AbstractPPL.varname_and_value_leaves,
60+
keys(ps.params),
61+
values(ps.params),
62+
)
63+
mapreduce(collect, vcat, iters)
64+
end
65+
vn_leaves = map(first, vn_leaves_and_vals)
66+
vals = map(last, vn_leaves_and_vals)
67+
for vn_leaf in vn_leaves
68+
push!(all_vn_leaves, vn_leaf)
69+
end
70+
DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
71+
end
72+
vn_leaves = collect(all_vn_leaves)
73+
param_vals = [
74+
get(split_dicts[i, j], key, missing) for i in eachindex(axes(split_dicts, 1)),
75+
key in vn_leaves, j in eachindex(axes(split_dicts, 2))
76+
]
77+
param_symbols = map(Symbol, vn_leaves)
78+
# Handle statistics
79+
stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}()
80+
for ps in params_and_stats
81+
for k in keys(ps.stats)
82+
push!(stat_keys, k)
83+
end
84+
end
85+
stat_keys = collect(stat_keys)
86+
stat_vals = [
87+
get(params_and_stats[i, j].stats, key, missing) for
88+
i in eachindex(axes(params_and_stats, 1)), key in stat_keys,
89+
j in eachindex(axes(params_and_stats, 2))
90+
]
91+
# Construct name map and info
92+
name_map = (internals=stat_keys,)
93+
info = (
94+
varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
95+
zip(all_vn_leaves, param_symbols)
96+
),
97+
)
98+
# Concatenate parameter and statistic values
99+
vals = cat(param_vals, stat_vals; dims=2)
100+
symbols = vcat(param_symbols, stat_keys)
101+
return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info)
102+
end
103+
function DynamicPPL.to_chains(
104+
::Type{MCMCChains.Chains}, ps::AbstractVector{<:DynamicPPL.ParamsWithStats}
105+
)
106+
return DynamicPPL.to_chains(MCMCChains.Chains, hcat(ps))
107+
end
108+
39109
"""
40110
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
41111
@@ -110,7 +180,6 @@ function DynamicPPL.predict(
110180
DynamicPPL.VarInfo(),
111181
(
112182
DynamicPPL.LogPriorAccumulator(),
113-
DynamicPPL.LogJacobianAccumulator(),
114183
DynamicPPL.LogLikelihoodAccumulator(),
115184
DynamicPPL.ValuesAsInModelAccumulator(false),
116185
),
@@ -129,23 +198,9 @@ function DynamicPPL.predict(
129198
varinfo,
130199
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
131200
)
132-
vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
133-
varname_vals = mapreduce(
134-
collect,
135-
vcat,
136-
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
137-
)
138-
139-
return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
201+
DynamicPPL.ParamsWithStats(varinfo, nothing)
140202
end
141-
142-
chain_result = reduce(
143-
MCMCChains.chainscat,
144-
[
145-
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
146-
chain_idx in 1:size(predictive_samples, 2)
147-
],
148-
)
203+
chain_result = DynamicPPL.to_chains(MCMCChains.Chains, predictive_samples)
149204
parameter_names = if include_all
150205
MCMCChains.names(chain_result, :parameters)
151206
else
@@ -164,45 +219,6 @@ function DynamicPPL.predict(
164219
)
165220
end
166221

167-
function _predictive_samples_to_arrays(predictive_samples)
168-
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
169-
170-
sample_dicts = map(predictive_samples) do sample
171-
varname_value_pairs = sample.varname_and_values
172-
varnames = map(first, varname_value_pairs)
173-
values = map(last, varname_value_pairs)
174-
for varname in varnames
175-
push!(variable_names_set, varname)
176-
end
177-
178-
return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
179-
end
180-
181-
variable_names = collect(variable_names_set)
182-
variable_values = [
183-
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
184-
key in variable_names
185-
]
186-
187-
return variable_names, variable_values
188-
end
189-
190-
function _predictive_samples_to_chains(predictive_samples)
191-
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
192-
variable_names_symbols = map(Symbol, variable_names)
193-
194-
internal_parameters = [:lp]
195-
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
196-
197-
parameter_names = [variable_names_symbols; internal_parameters]
198-
parameter_values = hcat(variable_values, log_probabilities)
199-
parameter_values = MCMCChains.concretize(parameter_values)
200-
201-
return MCMCChains.Chains(
202-
parameter_values, parameter_names, (internals=internal_parameters,)
203-
)
204-
end
205-
206222
"""
207223
returned(model::Model, chain::MCMCChains.Chains)
208224

src/DynamicPPL.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ export AbstractVarInfo,
126126
prefix,
127127
returned,
128128
to_submodel,
129+
# Chain construction
130+
ParamsWithStats,
131+
to_chains,
129132
# Convenience macros
130133
@addlogprob!,
131134
value_iterator_from_chain,
@@ -194,6 +197,7 @@ include("model_utils.jl")
194197
include("extract_priors.jl")
195198
include("values_as_in_model.jl")
196199
include("bijector.jl")
200+
include("to_chains.jl")
197201

198202
include("debug_utils.jl")
199203
using .DebugUtils

src/to_chains.jl

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
ParamsWithStats
3+
4+
A struct which contains parameter values extracted from a `VarInfo`, along with any
5+
statistics associated with the VarInfo. The statistics are provided as a NamedTuple and are
6+
optional.
7+
8+
ParamsWithStats(
9+
varinfo::AbstractVarInfo,
10+
model::Model,
11+
stats::NamedTuple=NamedTuple();
12+
include_colon_eq::Bool=true,
13+
include_log_probs::Bool=true,
14+
)
15+
16+
Generate a `ParamsWithStats` by re-evaluating the given `model` with the provided `varinfo`.
17+
Re-evaluation of the model is often necessary to obtain correct parameter values as well as
18+
log probabilities. This is especially true when using linked VarInfos, i.e., when variables
19+
have been transformed to unconstrained space, and if this is not done, subtle correctness
20+
bugs may arise: see, e.g., https://github.com/TuringLang/Turing.jl/issues/2195.
21+
22+
`include_colon_eq` controls whether variables on the left-hand side of `:=` are included in
23+
the resulting parameters.
24+
25+
`include_log_probs` controls whether log probabilities (log prior, log likelihood, and log
26+
joint) are added to the resulting statistics NamedTuple.
27+
28+
ParamsWithStats(
29+
varinfo::AbstractVarInfo,
30+
::Nothing,
31+
stats::NamedTuple=NamedTuple();
32+
include_log_probs::Bool=true,
33+
)
34+
35+
There is one case where re-evaluation is not necessary, which is when the VarInfos all
36+
already contain `DynamicPPL.ValuesAsInModelAccumulator`. This accumulator stores values
37+
as seen during the model evaluation, so the values can be simply read off. In this case,
38+
`model` can be set to `nothing`, and no re-evaluation will be performed. However, it is the
39+
caller's responsibility to ensure that `ValuesAsInModelAccumulator` is indeed
40+
present.
41+
42+
`include_log_probs` controls whether log probabilities (log prior, log likelihood, and log
43+
joint) are added to the resulting statistics NamedTuple.
44+
"""
45+
struct ParamsWithStats{P<:OrderedDict{VarName,Any},S<:NamedTuple}
46+
params::P
47+
stats::S
48+
49+
function ParamsWithStats(
50+
varinfo::AbstractVarInfo,
51+
model::DynamicPPL.Model,
52+
stats::NamedTuple=NamedTuple();
53+
include_colon_eq::Bool=true,
54+
include_log_probs::Bool=true,
55+
)
56+
accs = if include_log_probs
57+
(
58+
DynamicPPL.LogPriorAccumulator(),
59+
DynamicPPL.LogLikelihoodAccumulator(),
60+
DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),
61+
)
62+
else
63+
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
64+
end
65+
varinfo = DynamicPPL.setaccs!!(varinfo, accs)
66+
varinfo = last(DynamicPPL.evaluate!!(model, varinfo))
67+
params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
68+
if include_log_probs
69+
stats = merge(
70+
stats,
71+
(
72+
logprior=DynamicPPL.getlogprior(varinfo),
73+
loglikelihood=DynamicPPL.getloglikelihood(varinfo),
74+
lp=DynamicPPL.getlogjoint(varinfo),
75+
),
76+
)
77+
end
78+
return new{typeof(params),typeof(stats)}(params, stats)
79+
end
80+
81+
function ParamsWithStats(
82+
varinfo::AbstractVarInfo,
83+
::Nothing,
84+
stats::NamedTuple=NamedTuple();
85+
include_log_probs::Bool=true,
86+
)
87+
params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
88+
if include_log_probs
89+
has_prior_acc = DynamicPPL.hasacc(varinfo, Val(:LogPrior))
90+
has_likelihood_acc = DynamicPPL.hasacc(varinfo, Val(:LogLikelihood))
91+
if has_prior_acc
92+
stats = merge(stats, (logprior=DynamicPPL.getlogprior(varinfo),))
93+
end
94+
if has_likelihood_acc
95+
stats = merge(stats, (loglikelihood=DynamicPPL.getloglikelihood(varinfo),))
96+
end
97+
if has_prior_acc && has_likelihood_acc
98+
stats = merge(stats, (logjoint=DynamicPPL.getlogjoint(varinfo),))
99+
end
100+
end
101+
return new{typeof(params),typeof(stats)}(params, stats)
102+
end
103+
end
104+
105+
"""
106+
to_chains(
107+
Tout::Type{<:AbstractChains},
108+
params_and_stats::AbstractArray{<:ParamsWithStats}
109+
)
110+
111+
Convert an array of `ParamsWithStats` to a chains object of type `Tout`.
112+
113+
This function is not implemented here but rather in package extensions for individual chains
114+
packages.
115+
"""
116+
function to_chains end

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,43 @@
1111
chain_generated = @test_nowarn returned(model, chain)
1212
@test size(chain_generated) == (1000, 1)
1313
@test mean(chain_generated) 0 atol = 0.1
14+
15+
@testset "varinfos_to_chains" begin
16+
@model function f(z)
17+
x ~ Normal()
18+
y := x + 1
19+
return z ~ Normal(y)
20+
end
21+
22+
z = 1.0
23+
model = f(z)
24+
25+
@testset "vector" begin
26+
ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50]
27+
c = DynamicPPL.to_chains(MCMCChains.Chains, ps)
28+
@test c isa MCMCChains.Chains
29+
@test size(c, 1) == 50
30+
@test size(c, 3) == 1
31+
@test Set(c.name_map.parameters) == Set([:x, :y])
32+
@test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp])
33+
@test logpdf.(Normal(), c[:x]) c[:logprior]
34+
@test c.info.varname_to_symbol[@varname(x)] == :x
35+
@test c.info.varname_to_symbol[@varname(y)] == :y
36+
end
37+
38+
@testset "matrix" begin
39+
ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50, _ in 1:3]
40+
c = DynamicPPL.to_chains(MCMCChains.Chains, ps)
41+
@test c isa MCMCChains.Chains
42+
@test size(c, 1) == 50
43+
@test size(c, 3) == 3
44+
@test Set(c.name_map.parameters) == Set([:x, :y])
45+
@test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp])
46+
@test logpdf.(Normal(), c[:x]) c[:logprior]
47+
@test c.info.varname_to_symbol[@varname(x)] == :x
48+
@test c.info.varname_to_symbol[@varname(y)] == :y
49+
end
50+
end
1451
end
1552

1653
# test for `predict` is in `test/model.jl`

0 commit comments

Comments
 (0)