@@ -176,13 +176,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
176
176
# this means that the code below will work both of linked and invlinked `vi`.
177
177
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
178
178
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
179
- vals = DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
180
-
181
- # Obtain an iterator over the flattened parameter names and values.
182
- iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
183
-
184
- # Materialize the iterators and concatenate.
185
- return mapreduce (collect, vcat, iters)
179
+ return DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
186
180
end
187
181
function getparams (
188
182
model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
@@ -193,14 +187,25 @@ function getparams(
193
187
return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
194
188
end
195
189
function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
196
- return float (Real)[]
190
+ return Dict {VarName,Any} ()
197
191
end
198
192
199
193
function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
200
194
names_set = OrderedSet {VarName} ()
201
195
# Extract the parameter names and values from each transition.
202
196
dicts = map (ts) do t
203
- nms_and_vs = getparams (model, t)
197
+ # In general getparams returns a dict of VarName => values. We need to also
198
+ # split it up into constituent elements using
199
+ # `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
200
+ # won't understand it.
201
+ vals = getparams (model, t)
202
+ nms_and_vs = if isempty (vals)
203
+ Tuple{VarName,Any}[]
204
+ else
205
+ iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
206
+ mapreduce (collect, vcat, iters)
207
+ end
208
+
204
209
nms = map (first, nms_and_vs)
205
210
vs = map (last, nms_and_vs)
206
211
for nm in nms
@@ -210,9 +215,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
210
215
return OrderedDict (zip (nms, vs))
211
216
end
212
217
names = collect (names_set)
213
- vals = [
214
- get (dicts[i], key, missing ) for i in eachindex (dicts), (j, key) in enumerate (names)
215
- ]
218
+ vals = [get (dicts[i], key, missing ) for i in eachindex (dicts), key in names]
216
219
217
220
return names, vals
218
221
end
0 commit comments