@@ -36,30 +36,51 @@ function chain_sample_to_varname_dict(
3636 return d
3737end
3838
39+ struct ModelEvaluationData{
40+ Tval<: AbstractDict ,Tpr<: Union{AbstractFloat,Missing} ,Tlk<: Union{AbstractFloat,Missing}
41+ }
42+ values:: Tval
43+ logprior:: Tpr
44+ loglikelihood:: Tlk
45+ end
46+
3947"""
4048 DynamicPPL.varinfos_to_chains(
4149 ::Type{MCMCChains.Chains},
42- model::Model,
4350 varinfos::AbstractArray{<:AbstractVarInfo},
51+ model::Union{DynamicPPL.Model,Nothing};
4452 include_colon_eq::Bool=true,
53+ include_log_probs::Bool=true,
4554 )
4655
4756Convert an array of `VarInfo`s to an `MCMCChains.Chains` object.
4857"""
4958function DynamicPPL. varinfos_to_chains (
5059 :: Type{MCMCChains.Chains} ,
51- model:: DynamicPPL.Model ,
5260 varinfos:: AbstractMatrix{<:DynamicPPL.AbstractVarInfo} ,
61+ model:: Union{DynamicPPL.Model,Nothing} ;
5362 include_colon_eq:: Bool = true ,
63+ include_log_probs:: Bool = true ,
5464)
5565 all_vn_leaves = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
56- # Re-evaluate model
57- param_dicts = map (varinfos) do vi
58- # Dict{VarName, Any}
59- full_vals = DynamicPPL. values_as_in_model (model, include_colon_eq, vi)
66+ evaluation_data = map (varinfos) do vi
67+ # Re-evaluate model if needed
68+ new_vi = if model isa DynamicPPL. Model
69+ accs = (
70+ DynamicPPL. LogPriorAccumulator (),
71+ DynamicPPL. LogJacobianAccumulator (),
72+ DynamicPPL. LogLikelihoodAccumulator (),
73+ DynamicPPL. ValuesAsInModelAccumulator (include_colon_eq),
74+ )
75+ vi = DynamicPPL. setaccs!! (vi, accs)
76+ last (DynamicPPL. evaluate!! (model, vi))
77+ else
78+ vi
79+ end
80+ full_vals = DynamicPPL. getacc (new_vi, Val (:ValuesAsInModel )). values
6081 # Separate into individual VarNames.
6182 vn_leaves_and_vals = if isempty (full_vals)
62- Tuple{VarName,Any}[]
83+ Tuple{DynamicPPL . VarName,Any}[]
6384 else
6485 iters = map (
6586 AbstractPPL. varname_and_value_leaves,
@@ -73,29 +94,58 @@ function DynamicPPL.varinfos_to_chains(
7394 for vn_leaf in vn_leaves
7495 push! (all_vn_leaves, vn_leaf)
7596 end
76- return DynamicPPL. OrderedCollections. OrderedDict (zip (vn_leaves, vals))
97+ vn_leaves_dict = DynamicPPL. OrderedCollections. OrderedDict (zip (vn_leaves, vals))
98+ logprior = if include_log_probs && DynamicPPL. hasacc (new_vi, Val (:LogPrior ))
99+ DynamicPPL. getacc (new_vi, Val (:LogPrior )). logp
100+ else
101+ missing
102+ end
103+ loglikelihood =
104+ if include_log_probs && DynamicPPL. hasacc (new_vi, Val (:LogLikelihood ))
105+ DynamicPPL. getacc (new_vi, Val (:LogLikelihood )). logp
106+ else
107+ missing
108+ end
109+ return ModelEvaluationData (vn_leaves_dict, logprior, loglikelihood)
77110 end
78111 vn_leaves = collect (all_vn_leaves)
79112 vals = [
80- get (param_dicts[i, j], key, missing ) for i in eachindex (axes (param_dicts, 1 )),
81- key in vn_leaves, j in eachindex (axes (param_dicts, 2 ))
113+ get (evaluation_data[i, j]. values, key, missing ) for
114+ i in eachindex (axes (evaluation_data, 1 )), key in vn_leaves,
115+ j in eachindex (axes (evaluation_data, 2 ))
82116 ]
83117 symbols = map (Symbol, vn_leaves)
84118 info = (
85119 varname_to_symbol= DynamicPPL. OrderedCollections. OrderedDict (
86120 zip (all_vn_leaves, symbols)
87121 ),
88122 )
89- return MCMCChains. Chains (MCMCChains. concretize (vals), symbols; info= info)
123+ name_map = NamedTuple ()
124+ if include_log_probs
125+ logpriors = map (e -> e. logprior, evaluation_data)
126+ loglikelihoods = map (e -> e. loglikelihood, evaluation_data)
127+ logjoints = map (e -> e. logprior + e. loglikelihood, evaluation_data)
128+ lps = permutedims (stack ([logpriors, loglikelihoods, logjoints]), (1 , 3 , 2 ))
129+ vals = hcat (vals, lps)
130+ lp_symbols = [:logprior , :loglikelihood , :lp ]
131+ append! (symbols, lp_symbols)
132+ name_map = (internals= lp_symbols,)
133+ end
134+ return MCMCChains. Chains (MCMCChains. concretize (vals), symbols, name_map; info= info)
90135end
91136function DynamicPPL. varinfos_to_chains (
92137 :: Type{MCMCChains.Chains} ,
93- model:: DynamicPPL.Model ,
94138 varinfos:: AbstractVector{<:DynamicPPL.AbstractVarInfo} ,
139+ model:: Union{DynamicPPL.Model,Nothing} ;
95140 include_colon_eq:: Bool = true ,
141+ include_log_probs:: Bool = true ,
96142)
97143 return DynamicPPL. varinfos_to_chains (
98- MCMCChains. Chains, model, hcat (varinfos), include_colon_eq
144+ MCMCChains. Chains,
145+ hcat (varinfos),
146+ model;
147+ include_colon_eq= include_colon_eq,
148+ include_log_probs= include_log_probs,
99149 )
100150end
101151
0 commit comments