@@ -9,7 +9,7 @@ struct DefaultTransformation <: AbstractTransformation end
99A simple wrapper of the parameters with a `logp` field for
1010accumulation of the logdensity.
1111
12- Currently only implemented for `NT<:NamedTuple` and `NT<:Dict `.
12+ Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict `.
1313
1414# Fields
1515$(FIELDS)
@@ -69,8 +69,8 @@ julia> # (×) If we don't provide the container...
6969ERROR: type NamedTuple has no field x
7070[...]
7171
72- julia> # If one does not know the varnames, we can use a `Dict ` instead.
73- _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict ()), ctx);
72+ julia> # If one does not know the varnames, we can use a `OrderedDict ` instead.
73+ _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict ()), ctx);
7474
7575julia> # (✓) Sort of fast, but only possible at runtime.
7676 vi[@varname(x[1])]
@@ -86,6 +86,11 @@ ERROR: KeyError: key x[1:2] not found
8686[...]
8787```
8888
89+ _Technically_, it's possible to use any implementation of `AbstractDict` in place of
90+ `OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening
91+ of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is
92+ the preferred implementation of `AbstractDict` to use here.
93+
8994You can also sample in _transformed_ space:
9095
9196```jldoctest simplevarinfo-general
@@ -109,8 +114,8 @@ julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo()
109114julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
110115true
111116
112- julia> # And with `Dict ` of course!
113- _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict ()), true), ctx);
117+ julia> # And with `OrderedDict ` of course!
118+ _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict ()), true), ctx);
114119
115120julia> vi[@varname(x)] # (✓) -∞ < x < ∞
1161210.6225185067787314
@@ -165,9 +170,9 @@ ERROR: type NamedTuple has no field b
165170[...]
166171```
167172
168- Using `Dict ` as underlying storage.
173+ Using `OrderedDict ` as underlying storage.
169174```jldoctest
170- julia> svi_dict = SimpleVarInfo(Dict (@varname(m) => (a = [1.0], )));
175+ julia> svi_dict = SimpleVarInfo(OrderedDict (@varname(m) => (a = [1.0], )));
171176
172177julia> svi_dict[@varname(m)]
173178(a = [1.0],)
274279
275280Base. getindex (vi:: SimpleVarInfo , vn:: VarName ) = get (vi. values, vn)
276281
277- # `Dict `
282+ # `AbstractDict `
278283function Base. getindex (vi:: SimpleVarInfo{<:AbstractDict} , vn:: VarName )
279284 return nested_getindex (vi. values, vn)
280285end
@@ -364,7 +369,7 @@ function BangBang.push!!(
364369 return Setfield. @set vi. values = set!! (vi. values, vn, value)
365370end
366371
367- # `Dict `
372+ # `AbstractDict `
368373function BangBang. push!! (
369374 vi:: SimpleVarInfo{<:AbstractDict} ,
370375 vn:: VarName ,
@@ -473,17 +478,14 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
473478istrans (vi:: SimpleVarInfo , vn:: VarName ) = istrans (vi)
474479istrans (vi:: ThreadSafeVarInfo{<:SimpleVarInfo} , vn:: VarName ) = istrans (vi. varinfo, vn)
475480
476- """
477- values_as(varinfo[, Type])
478-
479- Return the values/realizations in `varinfo` as `Type`, if implemented.
480-
481- If no `Type` is provided, return values as stored in `varinfo`.
482- """
483481values_as (vi:: SimpleVarInfo ) = vi. values
484- values_as (vi:: SimpleVarInfo , :: Type{Dict} ) = Dict (pairs (vi. values))
485- values_as (vi:: SimpleVarInfo , :: Type{NamedTuple} ) = NamedTuple (pairs (vi. values))
486- values_as (vi:: SimpleVarInfo{<:NamedTuple} , :: Type{NamedTuple} ) = vi. values
482+ values_as (vi:: SimpleVarInfo{<:T} , :: Type{T} ) where {T} = vi. values
483+ function values_as (vi:: SimpleVarInfo , :: Type{D} ) where {D<: AbstractDict }
484+ return ConstructionBase. constructorof (D)(zip (keys (vi), values (vi. values)))
485+ end
486+ function values_as (vi:: SimpleVarInfo{<:AbstractDict} , :: Type{NamedTuple} )
487+ return NamedTuple ((Symbol (k), v) for (k, v) in vi. values)
488+ end
487489
488490"""
489491 logjoint(model::Model, θ)
0 commit comments