@@ -36,6 +36,76 @@ function chain_sample_to_varname_dict(
3636 return d
3737end
3838
39+ """
40+ DynamicPPL.to_chains(
41+ ::Type{MCMCChains.Chains},
42+ params_and_stats::AbstractArray{<:ParamsWithStats}
43+ )
44+
45+ Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
46+ """
47+ function DynamicPPL. to_chains (
48+ :: Type{MCMCChains.Chains} ,
49+ params_and_stats:: AbstractMatrix{<:DynamicPPL.ParamsWithStats} ,
50+ )
51+ # Handle parameters
52+ all_vn_leaves = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
53+ split_dicts = map (params_and_stats) do ps
54+ # Separate into individual VarNames.
55+ vn_leaves_and_vals = if isempty (ps. params)
56+ Tuple{DynamicPPL. VarName,Any}[]
57+ else
58+ iters = map (
59+ AbstractPPL. varname_and_value_leaves,
60+ keys (ps. params),
61+ values (ps. params),
62+ )
63+ mapreduce (collect, vcat, iters)
64+ end
65+ vn_leaves = map (first, vn_leaves_and_vals)
66+ vals = map (last, vn_leaves_and_vals)
67+ for vn_leaf in vn_leaves
68+ push! (all_vn_leaves, vn_leaf)
69+ end
70+ DynamicPPL. OrderedCollections. OrderedDict (zip (vn_leaves, vals))
71+ end
72+ vn_leaves = collect (all_vn_leaves)
73+ param_vals = [
74+ get (split_dicts[i, j], key, missing ) for i in eachindex (axes (split_dicts, 1 )),
75+ key in vn_leaves, j in eachindex (axes (split_dicts, 2 ))
76+ ]
77+ param_symbols = map (Symbol, vn_leaves)
78+ # Handle statistics
79+ stat_keys = DynamicPPL. OrderedCollections. OrderedSet {Symbol} ()
80+ for ps in params_and_stats
81+ for k in keys (ps. stats)
82+ push! (stat_keys, k)
83+ end
84+ end
85+ stat_keys = collect (stat_keys)
86+ stat_vals = [
87+ get (params_and_stats[i, j]. stats, key, missing ) for
88+ i in eachindex (axes (params_and_stats, 1 )), key in stat_keys,
89+ j in eachindex (axes (params_and_stats, 2 ))
90+ ]
91+ # Construct name map and info
92+ name_map = (internals= stat_keys,)
93+ info = (
94+ varname_to_symbol= DynamicPPL. OrderedCollections. OrderedDict (
95+ zip (all_vn_leaves, param_symbols)
96+ ),
97+ )
98+ # Concatenate parameter and statistic values
99+ vals = cat (param_vals, stat_vals; dims= 2 )
100+ symbols = vcat (param_symbols, stat_keys)
101+ return MCMCChains. Chains (MCMCChains. concretize (vals), symbols, name_map; info= info)
102+ end
103+ function DynamicPPL. to_chains (
104+ :: Type{MCMCChains.Chains} , ps:: AbstractVector{<:DynamicPPL.ParamsWithStats}
105+ )
106+ return DynamicPPL. to_chains (MCMCChains. Chains, hcat (ps))
107+ end
108+
39109"""
40110 predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
41111
@@ -110,7 +180,6 @@ function DynamicPPL.predict(
110180 DynamicPPL. VarInfo (),
111181 (
112182 DynamicPPL. LogPriorAccumulator (),
113- DynamicPPL. LogJacobianAccumulator (),
114183 DynamicPPL. LogLikelihoodAccumulator (),
115184 DynamicPPL. ValuesAsInModelAccumulator (false ),
116185 ),
@@ -129,23 +198,9 @@ function DynamicPPL.predict(
129198 varinfo,
130199 DynamicPPL. InitFromParams (values_dict, DynamicPPL. InitFromPrior ()),
131200 )
132- vals = DynamicPPL. getacc (varinfo, Val (:ValuesAsInModel )). values
133- varname_vals = mapreduce (
134- collect,
135- vcat,
136- map (AbstractPPL. varname_and_value_leaves, keys (vals), values (vals)),
137- )
138-
139- return (varname_and_values= varname_vals, logp= DynamicPPL. getlogjoint (varinfo))
201+ DynamicPPL. ParamsWithStats (varinfo, nothing )
140202 end
141-
142- chain_result = reduce (
143- MCMCChains. chainscat,
144- [
145- _predictive_samples_to_chains (predictive_samples[:, chain_idx]) for
146- chain_idx in 1 : size (predictive_samples, 2 )
147- ],
148- )
203+ chain_result = DynamicPPL. to_chains (MCMCChains. Chains, predictive_samples)
149204 parameter_names = if include_all
150205 MCMCChains. names (chain_result, :parameters )
151206 else
@@ -164,45 +219,6 @@ function DynamicPPL.predict(
164219 )
165220end
166221
167- function _predictive_samples_to_arrays (predictive_samples)
168- variable_names_set = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
169-
170- sample_dicts = map (predictive_samples) do sample
171- varname_value_pairs = sample. varname_and_values
172- varnames = map (first, varname_value_pairs)
173- values = map (last, varname_value_pairs)
174- for varname in varnames
175- push! (variable_names_set, varname)
176- end
177-
178- return DynamicPPL. OrderedCollections. OrderedDict (zip (varnames, values))
179- end
180-
181- variable_names = collect (variable_names_set)
182- variable_values = [
183- get (sample_dicts[i], key, missing ) for i in eachindex (sample_dicts),
184- key in variable_names
185- ]
186-
187- return variable_names, variable_values
188- end
189-
190- function _predictive_samples_to_chains (predictive_samples)
191- variable_names, variable_values = _predictive_samples_to_arrays (predictive_samples)
192- variable_names_symbols = map (Symbol, variable_names)
193-
194- internal_parameters = [:lp ]
195- log_probabilities = reshape ([sample. logp for sample in predictive_samples], :, 1 )
196-
197- parameter_names = [variable_names_symbols; internal_parameters]
198- parameter_values = hcat (variable_values, log_probabilities)
199- parameter_values = MCMCChains. concretize (parameter_values)
200-
201- return MCMCChains. Chains (
202- parameter_values, parameter_names, (internals= internal_parameters,)
203- )
204- end
205-
206222"""
207223 returned(model::Model, chain::MCMCChains.Chains)
208224
0 commit comments