diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ec5e1ea10..9f9a3720a 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -595,7 +595,7 @@ OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: m => 2.0 julia> values_as(vi, Vector) -2-element Vector{Real}: +2-element Vector{Float64}: 1.0 2.0 ``` diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..3a08b8896 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -90,7 +90,7 @@ the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both have the same symbol `x`. Several type aliases are provided for these forms of VarInfos: -- `VarInfo{<:Metadata}` is `UntypedVarInfo` +- `VarInfo{<:Metadata}` is `UntypedLegacyVarInfo` - `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` - `VarInfo{<:NamedTuple}` is `NTVarInfo` @@ -107,7 +107,7 @@ struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta accs::Accs end -function VarInfo(meta=Metadata()) +function VarInfo(meta=VarNamedVector()) return VarInfo(meta, default_accumulators()) end @@ -143,7 +143,7 @@ function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} -const UntypedVarInfo = VarInfo{<:Metadata} +const UntypedLegacyVarInfo = VarInfo{<:Metadata} # TODO: NTVarInfo carries no information about the type of the actual metadata # i.e. the elements of the NamedTuple. It could be Metadata or it could be # VarNamedVector. @@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } +const UntypedVarInfo = UntypedVectorVarInfo function Base.:(==)(vi1::VarInfo, vi2::VarInfo) return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) @@ -194,8 +195,20 @@ end # VarInfo constructors # ######################## +function untyped_varinfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return untyped_vector_varinfo(rng, model, init_strategy) +end + +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) +end + """ - untyped_varinfo([rng, ]model[, init_strategy]) + untyped_legacy_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -205,19 +218,21 @@ Construct a VarInfo object for the given `model`, which has just a single - `model::Model`: The model for which to create the varinfo object - `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ -function untyped_varinfo( +function untyped_legacy_varinfo( rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return untyped_varinfo(Random.default_rng(), model, init_strategy) +function untyped_legacy_varinfo( + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + return untyped_legacy_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_varinfo(vi::UntypedVarInfo) + typed_legacy_varinfo(vi::UntypedLegacyVarInfo) This function finds all the unique `sym`s from the instances of `VarName{sym}` found in `vi.metadata.vns`. It then extracts the metadata associated with each symbol from the @@ -225,7 +240,7 @@ global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `meta a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each symbol. """ -function typed_varinfo(vi::UntypedVarInfo) +function typed_legacy_varinfo(vi::UntypedLegacyVarInfo) meta = vi.metadata new_metas = Metadata[] # Symbols of all instances of `VarName{sym}` in `vi.vns` @@ -289,12 +304,16 @@ function typed_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) + return typed_vector_varinfo(rng, model, init_strategy) end function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) return typed_varinfo(Random.default_rng(), model, init_strategy) end +function typed_varinfo(vi::UntypedVectorVarInfo) + return typed_vector_varinfo(vi) +end + """ untyped_vector_varinfo([rng, ]model[, init_strategy]) @@ -306,7 +325,7 @@ Return a VarInfo object for the given `model`, which has just a single - `model::Model`: The model for which to create the varinfo object - `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ -function untyped_vector_varinfo(vi::UntypedVarInfo) +function untyped_vector_varinfo(vi::UntypedLegacyVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, copy(vi.accs)) end @@ -626,11 +645,11 @@ end const VarView = Union{Int,UnitRange,Vector{Int}} """ - setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) + setval!(vi::UntypedLegacyVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) Set the value of `vi.vals[vview]` to `val`. """ -setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val +setval!(vi::UntypedLegacyVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val """ getmetadata(vi::VarInfo, vn::VarName) @@ -825,10 +844,10 @@ set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, Returns a tuple of the unique symbols of random variables in `vi`. """ -syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols +syms(vi::UntypedLegacyVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::NTVarInfo) = keys(vi.metadata) -_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) +_getidcs(vi::UntypedLegacyVarInfo) = 1:length(vi.metadata.idcs) _getidcs(vi::NTVarInfo) = _getidcs(vi.metadata) @generated function _getidcs(metadata::NamedTuple{names}) where {names} @@ -949,7 +968,7 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!!(vi::UntypedVarInfo, vns) +function _link!!(vi::UntypedLegacyVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~is_transformed(vi, vns[1]) for vn in vns @@ -1063,7 +1082,7 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) return maybe_invlink_before_eval!!(t, vi, model) end -function _invlink!!(vi::UntypedVarInfo, vns) +function _invlink!!(vi::UntypedLegacyVarInfo, vns) if is_transformed(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1477,7 +1496,7 @@ function _invlink_metadata!!( end # TODO(mhauru) The treatment of the case when some variables are transformed and others are -# not should be revised. It used to be the case that for UntypedVarInfo `is_transformed` +# not should be revised. It used to be the case that for UntypedLegacyVarInfo `is_transformed` # returned whether the first variable was linked. For NTVarInfo we did an OR over the first # variables under each symbol. We now more consistently use OR, but I'm not convinced this # is really the right thing to do. @@ -1567,9 +1586,15 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) + function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) - setindex!(vi, val, vn) - return vi + md = setindex!!(getmetadata(vi, vn), val, vn) + return VarInfo(md, vi.accs) +end + +function BangBang.setindex!!(vi::NTVarInfo, val, vn::VarName) + submd = setindex!!(getmetadata(vi, vn), val, vn) + return Accessors.@set vi.metadata[getsym(vn)] = submd end @inline function findvns(vi, f_vns) @@ -1594,7 +1619,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) return any(md_haskey) end -function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) +function Base.show(io::IO, ::MIME"text/plain", vi::UntypedLegacyVarInfo) lines = Tuple{String,Any}[ ("VarNames", vi.metadata.vns), ("Range", vi.metadata.ranges), @@ -1649,7 +1674,7 @@ function _show_varnames(io::IO, vi) end end -function Base.show(io::IO, vi::UntypedVarInfo) +function Base.show(io::IO, vi::UntypedLegacyVarInfo) print(io, "VarInfo (") _show_varnames(io, vi) print(io, "; accumulators: ") @@ -1821,11 +1846,11 @@ end values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) -function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) +function values_as(vi::UntypedLegacyVarInfo, ::Type{NamedTuple}) iter = values_from_metadata(vi.metadata) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) end -function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict} +function values_as(vi::UntypedLegacyVarInfo, ::Type{D}) where {D<:AbstractDict} return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata)) end diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..70e2fec86 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -417,12 +417,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "InitContext" begin empty_varinfos = [ - ("untyped+metadata", VarInfo()), - ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), + ("untyped+metadata", VarInfo(DynamicPPL.Metadata())), + ( + "typed+metadata", + DynamicPPL.typed_legacy_varinfo(VarInfo(DynamicPPL.Metadata())), + ), ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), ( "typed+VNV", - DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + DynamicPPL.typed_vector_varinfo(VarInfo(DynamicPPL.VarNamedVector())), ), ("SVI+NamedTuple", SimpleVarInfo()), ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), diff --git a/test/debug_utils.jl b/test/debug_utils.jl index f950f6b45..45c7415d6 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -76,8 +76,11 @@ return nothing end buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) + @test_throws "should not subsume each other" DynamicPPL.untyped_varinfo( + buggy_model + ) + varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model) @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) issuccess = check_model(buggy_model, varinfo) @test !issuccess @@ -94,8 +97,11 @@ return nothing end buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) + @test_throws "should not subsume each other" DynamicPPL.untyped_varinfo( + buggy_model + ) + varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model) @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) issuccess = check_model(buggy_model, varinfo) @test !issuccess @@ -112,8 +118,11 @@ return nothing end buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) + @test_throws "should not subsume each other" DynamicPPL.untyped_varinfo( + buggy_model + ) + varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model) @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) issuccess = check_model(buggy_model, varinfo) @test !issuccess diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index c74beefdb..00841b95e 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -40,7 +40,7 @@ end end @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa - DynamicPPL.UntypedVarInfo + DynamicPPL.NTVarInfo # In this model, the type error occurs in the user code rather than in DynamicPPL. @model function demo5() diff --git a/test/test_util.jl b/test/test_util.jl index 164751c7b..3a7ea0028 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -23,7 +23,7 @@ function short_varinfo_name(vi::DynamicPPL.NTVarInfo) "TypedVarInfo" end end -short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::DynamicPPL.UntypedLegacyVarInfo) = "UntypedLegacyVarInfo" short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" diff --git a/test/varinfo.jl b/test/varinfo.jl index a1a1b370f..0a4c9d447 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -37,8 +37,10 @@ end end model = gdemo(1.0, 2.0) - _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform()) - tvi = DynamicPPL.typed_varinfo(vi) + # TODO(mhauru) Make this test more generic. It currently explicitly relies on + # Metadata. + _, vi = DynamicPPL.init!!(model, VarInfo(DynamicPPL.Metadata()), InitFromUniform()) + tvi = DynamicPPL.typed_legacy_varinfo(vi) meta = vi.metadata for f in fieldnames(typeof(tvi.metadata)) @@ -290,7 +292,7 @@ end dist = Normal(0, 1) r = rand(dist) - push!!(vi, vn_x, r, dist) + vi = push!!(vi, vn_x, r, dist) # is_transformed is set by default @test !is_transformed(vi, vn_x) @@ -353,7 +355,9 @@ end # worth specifically checking that it can do this without having to # change the VarInfo object. # TODO(penelopeysm): Move this to InitFromUniform tests rather than here. - vi = VarInfo() + # TODO(mhauru) Make this test more generic. It currently explicitly relies on + # Metadata. + vi = VarInfo(DynamicPPL.Metadata()) meta = vi.metadata _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) @test all(x -> !is_transformed(vi, x), meta.vns) @@ -367,7 +371,7 @@ end @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values - vi = DynamicPPL.typed_varinfo(vi) + vi = DynamicPPL.typed_legacy_varinfo(vi) meta = vi.metadata v_s = copy(meta.s.vals) v_m = copy(meta.m.vals)