Skip to content

Commit 2fecc74

Browse files
committed
Add from_chains as well
1 parent 9884f7a commit 2fecc74

File tree

4 files changed

+135
-3
lines changed

4 files changed

+135
-3
lines changed

docs/src/api.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,11 @@ This can be accomplished with the following:
515515
DynamicPPL.ParamsWithStats
516516
DynamicPPL.to_chains
517517
```
518+
519+
Furthermore, one can convert chains back into a collection of parameter dictionaries and/or stats with:
520+
521+
```@docs
522+
DynamicPPL.from_chains
523+
```
524+
525+
This is useful if you want to use the result of a chain in further model evaluations.

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,43 @@ function DynamicPPL.to_chains(
106106
return DynamicPPL.to_chains(MCMCChains.Chains, hcat(ps))
107107
end
108108

109+
function DynamicPPL.from_chains(
110+
::Type{T}, chain::MCMCChains.Chains
111+
) where {T<:AbstractDict{<:DynamicPPL.VarName}}
112+
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
113+
matrix = map(idxs) do (sample_idx, chain_idx)
114+
d = T()
115+
for vn in DynamicPPL.varnames(chain)
116+
d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
117+
end
118+
d
119+
end
120+
return matrix
121+
end
122+
function DynamicPPL.from_chains(::Type{NamedTuple}, chain::MCMCChains.Chains)
123+
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
124+
matrix = map(idxs) do (sample_idx, chain_idx)
125+
get(chain[sample_idx, :, chain_idx], keys(chain); flatten=true)
126+
end
127+
return matrix
128+
end
129+
function DynamicPPL.from_chains(
130+
::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains
131+
)
132+
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
133+
internals_chain = MCMCChains.get_sections(chain, :internals)
134+
params = DynamicPPL.from_chains(
135+
DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,eltype(chain.value)},
136+
chain,
137+
)
138+
stats = DynamicPPL.from_chains(NamedTuple, internals_chain)
139+
return map(idxs) do (sample_idx, chain_idx)
140+
DynamicPPL.ParamsWithStats(
141+
params[sample_idx, chain_idx], stats[sample_idx, chain_idx]
142+
)
143+
end
144+
end
145+
109146
"""
110147
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
111148

src/to_chains.jl

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,16 @@ present.
4242
`include_log_probs` controls whether log probabilities (log prior, log likelihood, and log
4343
joint) are added to the resulting statistics NamedTuple.
4444
"""
45-
struct ParamsWithStats{P<:OrderedDict{VarName,Any},S<:NamedTuple}
45+
struct ParamsWithStats{P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple}
4646
params::P
4747
stats::S
4848

49+
function ParamsWithStats(
50+
params::P, stats::S
51+
) where {P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple}
52+
return new{P,S}(params, stats)
53+
end
54+
4955
function ParamsWithStats(
5056
varinfo::AbstractVarInfo,
5157
model::DynamicPPL.Model,
@@ -113,11 +119,40 @@ maybe_to_typed_varinfo(vi::AbstractVarInfo) = vi
113119
to_chains(
114120
Tout::Type{<:AbstractChains},
115121
params_and_stats::AbstractArray{<:ParamsWithStats}
116-
)
122+
)::Tout
117123
118124
Convert an array of `ParamsWithStats` to a chains object of type `Tout`.
119125
120126
This function is not implemented here but rather in package extensions for individual chains
121127
packages.
122128
"""
123129
function to_chains end
130+
131+
"""
132+
from_chains(
133+
::Type{Tout},
134+
chain::AbstractChains
135+
)::AbstractMatrix{<:Tout}
136+
137+
Convert a chains object to an array of size (niters * nchains) with element type `Tout`.
138+
139+
Note that even if `chain` contains only a single chain, this is still returned as a matrix,
140+
not a vector.
141+
142+
This function is not implemented here but rather in package extensions for individual chains
143+
packages.
144+
145+
Common implementations include:
146+
147+
- `Tout = ParamsWithStats`: obtain both parameters and statistics
148+
- `Tout <: AbstractDict{<:VarName}`: obtain the parameters only (since stats are not stored
149+
as `VarName`s
150+
- `Tout = NamedTuple`: obtain both parameters and statistics as a NamedTuple
151+
152+
!!! warning
153+
Note that `Tout = NamedTuple` potentially causes a loss of information especially when
154+
used with `MCMCChains.Chains`, since variable names are not preserved. This may lead to
155+
bugs if the NamedTuple is later used for other purposes, such as evaluating a model. To
156+
avoid this, you should always use something like `Tout = OrderedDict{VarName,Any}`.
157+
"""
158+
function from_chains end

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
@test size(chain_generated) == (1000, 1)
1313
@test mean(chain_generated) 0 atol = 0.1
1414

15-
@testset "varinfos_to_chains" begin
15+
@testset "to_chains" begin
1616
@model function f(z)
1717
x ~ Normal()
1818
y := x + 1
@@ -48,6 +48,58 @@
4848
@test c.info.varname_to_symbol[@varname(y)] == :y
4949
end
5050
end
51+
52+
@testset "from_chains" begin
53+
@model function f(z)
54+
x ~ Normal()
55+
y := x + 1
56+
return z ~ Normal(y)
57+
end
58+
59+
z = 1.0
60+
model = f(z)
61+
ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50]
62+
c = DynamicPPL.to_chains(MCMCChains.Chains, ps)
63+
64+
@testset "OrderedDict" begin
65+
arr_dicts = DynamicPPL.from_chains(OrderedDict{VarName,Any}, c)
66+
@test size(arr_dicts) == (50, 1)
67+
for i in 1:50
68+
dict = arr_dicts[i, 1]
69+
@test dict isa OrderedDict{VarName,Any}
70+
p = ps[i].params
71+
@test dict[@varname(x)] == p[@varname(x)]
72+
@test dict[@varname(y)] == p[@varname(y)]
73+
@test length(dict) == 2
74+
end
75+
end
76+
77+
@testset "NamedTuple" begin
78+
arr_nts = DynamicPPL.from_chains(NamedTuple, c)
79+
@test size(arr_nts) == (50, 1)
80+
for i in 1:50
81+
nt = arr_nts[i, 1]
82+
@test length(nt) == 5
83+
p = ps[i]
84+
@test nt.x == p.params[@varname(x)]
85+
@test nt.y == p.params[@varname(y)]
86+
@test nt.lp == p.stats.lp
87+
@test nt.logprior == p.stats.logprior
88+
@test nt.loglikelihood == p.stats.loglikelihood
89+
end
90+
end
91+
92+
@testset "ParamsWithStats" begin
93+
arr_pss = DynamicPPL.from_chains(ParamsWithStats, c)
94+
@test size(arr_pss) == (50, 1)
95+
for i in 1:50
96+
new_p = arr_pss[i, 1]
97+
p = ps[i]
98+
@test new_p.params == p.params
99+
@test new_p.stats == p.stats
100+
end
101+
end
102+
end
51103
end
52104

53105
# test for `predict` is in `test/model.jl`

0 commit comments

Comments
 (0)