diff --git a/HISTORY.md b/HISTORY.md index 55010c533..309bbe011 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,45 @@ # DynamicPPL Changelog +## 0.38.0 + +**Breaking changes** + +foo + +**Other changes** + +### Thread-safe execution + +This release removes `ThreadSafeVarInfo`, which was a construct used to ensure thread-safe accumulation of log-likelihood terms when using `Threads.@threads`. +However, `Threads.@threads` is no longer the recommended way to perform multithreaded tasks: see e.g. [this Julia blog post](https://julialang.org/blog/2023/07/PSA-dont-use-threadid/). + +In its place, a new macro, `@pobserve`, is introduced, which uses `Threads.@spawn` under the hood. +**From a user's point of view, you simply need to replace `Threads.@threads` with `@pobserve`** (a brief before-and-after sketch is given below). +For example, here the likelihood contributions for each element of `y` are calculated in parallel: + +```julia +@model function f(y) + mu ~ Normal() + yplusones = @pobserve for i in eachindex(y) + y[i] ~ Normal(mu) + return y[i] + 1 + end +end +``` + +Furthermore, the `@pobserve` block also returns the final value inside the block, so it can also be used to parallelise arbitrary computation. In the model above, `yplusones` will be a vector of the same length as `y` where each element is `y[i] + 1`. + +Please note that this only works for **likelihood terms**, i.e., observed variables (hence the macro name). +It is a long-term goal to be able to parallelise other parts of model execution, such as the sampling of new variables, but this is not presently possible. + +`@pobserve` is also not currently compatible with Turing's particle samplers (because Libtask.jl does not work with `Threads.@spawn`). +This is, in fact, a good thing: previously, particle samplers combined with `Threads.@threads` would silently return incorrect results. + +### Other minor changes + +The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. +Their behaviour is otherwise identical. + ## 0.37.3 Prevents inlining of `DynamicPPL.istrans` with Enzyme, which allows Enzyme to differentiate models where `VarName`s have the same symbol but different types.
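As a rough before-and-after sketch of the migration described in the changelog entry above (the model `f` and the data `y` are purely illustrative; the old form relied on the `ThreadSafeVarInfo` machinery that this release removes):

```julia
using DynamicPPL, Distributions

# Old pattern (DynamicPPL 0.37 and earlier): thread-safe accumulation of the
# log-likelihood relied on ThreadSafeVarInfo under the hood.
@model function f_old(y)
    mu ~ Normal()
    Threads.@threads for i in eachindex(y)
        y[i] ~ Normal(mu)
    end
end

# New pattern (DynamicPPL 0.38): @pobserve spawns one task per iteration via
# Threads.@spawn and sums the per-element likelihood contributions.
@model function f_new(y)
    mu ~ Normal()
    @pobserve for i in eachindex(y)
        y[i] ~ Normal(mu)
    end
end
```

In both cases only the observe statements inside the loop are parallelised; latent variables such as `mu` are still handled serially.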
diff --git a/Project.toml b/Project.toml index 024aef5c3..f2e39b778 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.37.3" +version = "0.38.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -46,7 +46,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.13" +AbstractPPL = "0.13.1" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 3d14d03ff..cd4545cb9 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -22,7 +22,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.37" +DynamicPPL = "0.38" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" Mooncake = "0.4" diff --git a/docs/Project.toml b/docs/Project.toml index 1f01b11ef..124da3315 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -18,7 +18,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.37" +DynamicPPL = "0.38" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10" diff --git a/docs/src/api.md b/docs/src/api.md index 9a1923b53..f424c7836 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -160,6 +160,12 @@ It is possible to manually increase (or decrease) the accumulated log likelihood @addlogprob! ``` +If you want to perform observations in parallel (using Julia threads), you can use the following macro. + +```@docs +@pobserve +``` + Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`. ```@docs @@ -435,8 +441,6 @@ DynamicPPL.maybe_invlink_before_eval!! Base.merge(::AbstractVarInfo) DynamicPPL.subset DynamicPPL.unflatten -DynamicPPL.varname_leaves -DynamicPPL.varname_and_value_leaves ``` ### Evaluation Contexts @@ -449,11 +453,6 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. -To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: - -```@docs -DynamicPPL.evaluate_and_sample!! -``` The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. @@ -463,6 +462,32 @@ SamplingContext DefaultContext PrefixContext ConditionContext +InitContext +``` + +### VarInfo initialisation + +The function `init!!` is used to initialise, or overwrite, values in a VarInfo. +It is really a thin wrapper around using `evaluate!!` with an `InitContext`. + +```@docs +DynamicPPL.init!! +``` + +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: + +```@docs +InitFromPrior +InitFromUniform +InitFromParams +``` + +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. 
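As a minimal, hypothetical sketch of this extension point (the `InitFromMedian` strategy and the `demo` model below are invented for illustration and are not part of DynamicPPL; the method signature follows the `DynamicPPL.init` docstring added in this release):

```julia
using DynamicPPL, Distributions, Random, Statistics

# Hypothetical strategy: initialise every variable at the median of its prior.
# The value returned by `init` must lie in the untransformed support of `dist`,
# which the median always does.
struct InitFromMedian <: DynamicPPL.AbstractInitStrategy end

function DynamicPPL.init(
    ::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromMedian
)
    return median(dist)
end

@model function demo()
    x ~ LogNormal()
    y ~ Beta(2, 2)
end

# Overwrite a fresh VarInfo with values generated by the custom strategy.
_, vi = DynamicPPL.init!!(demo(), VarInfo(), InitFromMedian())
vi[@varname(x)]  # median(LogNormal()) == 1.0
```

A custom strategy like this can also be passed as the `fallback` of `InitFromParams`, so that any variables missing from the supplied parameters are filled in by it.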
+ +```@docs +DynamicPPL.AbstractInitStrategy +DynamicPPL.init ``` ### Samplers @@ -486,7 +511,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu ```@docs DynamicPPL.initialstep DynamicPPL.loadstate -DynamicPPL.initialsampler +DynamicPPL.init_strategy ``` Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 760d17bb0..55016d40c 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -6,7 +6,6 @@ using JET: JET function DynamicPPL.Experimental.is_suitable_varinfo( model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true ) - # Let's make sure that both evaluation and sampling doesn't result in type errors. f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) # If specified, we only check errors originating somewhere in the DynamicPPL.jl. # This way we don't just fall back to untyped if the user's code is the issue. @@ -21,32 +20,40 @@ end function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model; only_ddpl::Bool=true ) - # Use SamplingContext to test type stability. - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - - # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(sampling_model) + # Generate a typed varinfo to test model type stability with + varinfo = DynamicPPL.typed_varinfo(model) - # Let's make sure that both evaluation and sampling doesn't result in type errors. - issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - sampling_model, varinfo; only_ddpl + # Check type stability of evaluation (i.e. DefaultContext) + model = DynamicPPL.contextualize( + model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()) + ) + eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( + model, varinfo; only_ddpl ) + if !eval_issuccess + @debug "Evaluation with typed varinfo failed with the following issues:" + @debug eval_result + end - if !issuccess - # Useful information for debugging. - @debug "Evaluaton with typed varinfo failed with the following issues:" - @debug result + # Check type stability of initialisation (i.e. InitContext) + model = DynamicPPL.contextualize( + model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()) + ) + init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( + model, varinfo; only_ddpl + ) + if !init_issuccess + @debug "Initialisation with typed varinfo failed with the following issues:" + @debug init_result end - # If we didn't fail anywhere, we return the type stable one. - return if issuccess + # If neither of them failed, we can return the typed varinfo as it's type stable. + return if (eval_issuccess && init_issuccess) varinfo else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." 
- DynamicPPL.untyped_varinfo(sampling_model) + DynamicPPL.untyped_varinfo(model) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index a29696720..7b9322254 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,12 +1,7 @@ module DynamicPPLMCMCChainsExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using MCMCChains: MCMCChains -else - using ..DynamicPPL: DynamicPPL - using ..MCMCChains: MCMCChains -end +using DynamicPPL: DynamicPPL, AbstractPPL +using MCMCChains: MCMCChains # Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata function DynamicPPL.loadstate(chain::MCMCChains.Chains) @@ -28,7 +23,7 @@ end function _check_varname_indexing(c::MCMCChains.Chains) return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using `VarName`s.") + error("This `Chains` object does not support indexing using `VarName`s.") end function DynamicPPL.getindex_varname( @@ -42,6 +37,17 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +function chain_sample_to_varname_dict( + c::MCMCChains.Chains{Tval}, sample_idx, chain_idx +) where {Tval} + _check_varname_indexing(c) + d = Dict{DynamicPPL.VarName,Tval}() + for vn in DynamicPPL.varnames(c) + d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) + end + return d +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -114,14 +120,20 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) - DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) - + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict` + _, varinfo = DynamicPPL.init!!( + rng, + model, + varinfo, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( collect, vcat, - map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), + map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)), ) return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) @@ -248,13 +260,16 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to the `model`. - model(deepcopy(varinfo)) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict`, and + # return the model's retval. 
+ retval, _ = DynamicPPL.init!!( + model, + varinfo, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) + retval end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5c8233915..22011251b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -108,6 +108,12 @@ export AbstractVarInfo, ConditionContext, assume, tilde_assume, + # Initialisation + InitContext, + AbstractInitStrategy, + InitFromPrior, + InitFromUniform, + InitFromParams, # Pseudo distributions NamedDist, NoDist, @@ -127,6 +133,7 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, + @pobserve, value_iterator_from_chain, check_model, check_model_and_trace, @@ -169,21 +176,22 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/init.jl") include("model.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") -include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") +include("pobserve_macro.jl") include("pointwise_logdensities.jl") include("transforming.jl") include("logdensityfunction.jl") diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 786d7c913..326850fdf 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -135,7 +135,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, + vi::VarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..cd9876768 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -280,41 +280,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName return vn, setchildcontext(ctx, new_ctx) end -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. 
- -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} diff --git a/src/contexts/init.jl b/src/contexts/init.jl new file mode 100644 index 000000000..636847117 --- /dev/null +++ b/src/contexts/init.jl @@ -0,0 +1,196 @@ +""" + AbstractInitStrategy + +Abstract type representing the possible ways of initialising new values for +the random variables in a model (e.g., when creating a new VarInfo). + +Any subtype of `AbstractInitStrategy` must implement the +[`DynamicPPL.init`](@ref) method. +""" +abstract type AbstractInitStrategy end + +""" + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) + +Generate a new value for a random variable with the given distribution. + +!!! warning "Return values must be unlinked" + The values returned by `init` must always be in the untransformed space, i.e., + they must be within the support of the original distribution. That means that, + for example, `init(rng, dist, u::InitFromUniform)` will in general return values that + are outside the range [u.lower, u.upper]. +""" +function init end + +""" + InitFromPrior() + +Obtain new values by sampling from the prior distribution. +""" +struct InitFromPrior <: AbstractInitStrategy end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) + return rand(rng, dist) +end + +""" + InitFromUniform() + InitFromUniform(lower, upper) + +Obtain new values by first transforming the distribution of the random variable +to unconstrained space, then sampling a value uniformly between `lower` and +`upper`, and transforming that value back to the original space. + +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics +Stan's default initialisation strategy. + +Requires that `lower <= upper`. 
+ +# References + +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) +""" +struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy + lower::T + upper::T + function InitFromUniform(lower::T, upper::T) where {T<:AbstractFloat} + lower > upper && + throw(ArgumentError("`lower` must be less than or equal to `upper`")) + return new{T}(lower, upper) + end + InitFromUniform() = InitFromUniform(-2.0, 2.0) +end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform) + b = Bijectors.bijector(dist) + sz = Bijectors.output_size(b, size(dist)) + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) + b_inv = Bijectors.inverse(b) + x = b_inv(y) + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 + if x isa Array{<:Any,0} + x = x[] + end + return x +end + +""" + InitFromParams( + params::Union{AbstractDict{<:VarName},NamedTuple}, + fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + ) + +Obtain new values by extracting them from the given dictionary or NamedTuple. + +The parameter `fallback` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. `fallback` +can either be an initialisation strategy itself, in which case it will be +used to obtain new values, or it can be `nothing`, in which case an error +will be thrown. The default for `fallback` is `InitFromPrior()`. + +!!! note + The values in `params` must be provided in the space of the untransformed + distribution. +""" +struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy + params::P + fallback::S + function InitFromParams( + params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing} + ) + return new{typeof(params),typeof(fallback)}(params, fallback) + end + function InitFromParams(params::AbstractDict{<:VarName}) + return InitFromParams(params, InitFromPrior()) + end + function InitFromParams( + params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + ) + return InitFromParams(to_varname_dict(params), fallback) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + p.fallback === nothing && + error("A `missing` value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) + else + # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? + x + end + else + p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=InitFromPrior()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. 
+""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=InitFromPrior() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=InitFromPrior()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume( + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo alrady had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # istrans(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + f = if insert_transformed_value + link_transform(dist) + else + identity + end + y, logjac = with_logabsdet_jacobian(f, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. + insert_transformed_value && settrans!!(vi, true, vn) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!(::InitContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index c2be4b46b..19b88ec3f 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -425,8 +425,7 @@ function check_model_and_trace( # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) - # Force single-threaded execution. - _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo) + _, varinfo = DynamicPPL.evaluate!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index d311a5f63..8c7b5f7db 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -123,7 +123,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) = function extract_priors(rng::Random.AbstractRNG, model::Model) varinfo = VarInfo() varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) - varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/model.jl b/src/model.jl index 9f9c6ec3b..22f8d5b21 100644 --- a/src/model.jl +++ b/src/model.jl @@ -799,6 +799,41 @@ julia> # Now `a.x` will be sampled. """ fixed(model::Model) = fixed(model.context) +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. 
`@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) +end + """ (model::Model)([rng, varinfo]) @@ -815,43 +850,42 @@ end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(evaluate_and_sample!!(rng, model, varinfo)) -end - -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 + return first(init!!(rng, model, varinfo)) end """ - evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=InitFromPrior()] + ) -Evaluate the `model` with the given `varinfo`, but perform sampling during the -evaluation using the given `sampler` by wrapping the model's context in a -`SamplingContext`. +Evaluate the `model` and replace the values of the model's random variables +in the given `varinfo` with new values, using a specified initialisation strategy. +If the values in `varinfo` are not set, they will be added +using a specified initialisation strategy. -If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). +If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function evaluate_and_sample!!( +function init!!( rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo, - sampler::AbstractSampler=SampleFromPrior(), + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return evaluate!!(sampling_model, varinfo) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) + return evaluate!!(new_model, varinfo) end -function evaluate_and_sample!!( - model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() +function init!!( + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) + return init!!(Random.default_rng(), model, varinfo, init_strategy) end """ @@ -859,62 +893,19 @@ end Evaluate the `model` with the given `varinfo`. 
-If multiple threads are available, the varinfo provided will be wrapped in a -`ThreadSafeVarInfo` before evaluation. - Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if use_threadsafe_eval(model.context, varinfo) - evaluate_threadsafe!!(model, varinfo) - else - evaluate_threadunsafe!!(model, varinfo) - end -end - -""" - evaluate_threadunsafe!!(model, varinfo) - -Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. - -If the `model` makes use of Julia's multithreading this will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadsafe!!`](@ref) -""" -function evaluate_threadunsafe!!(model, varinfo) return _evaluate!!(model, resetaccs!!(varinfo)) end -""" - evaluate_threadsafe!!(model, varinfo, context) - -Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. - -With the wrapper, Julia's multithreading can be used for observe statements in the `model` -but parallel sampling will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadunsafe!!`](@ref) -""" -function evaluate_threadsafe!!(model, varinfo) - wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper) - # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it - # will return the underlying VI, which is a bit counterintuitive (because - # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it - # again). - return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) -end - """ _evaluate!!(model::Model, varinfo) Evaluate the `model` with the given `varinfo`. -This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not -reset the log probability of the `varinfo` before running. +This function does not reset the accumulators in the `varinfo` before running. """ function _evaluate!!(model::Model, varinfo::AbstractVarInfo) args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) @@ -981,11 +972,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate_and_sample!!( - rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) - ), - ) + x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) return values_as(x, T) end @@ -1157,25 +1144,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end end -""" - predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) - -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches -the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. 
-""" -function predict( - rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} -) - varinfo = DynamicPPL.VarInfo(model) - return map(chain) do params_varinfo - vi = deepcopy(varinfo) - DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi) - return vi - end -end +# Implemented & documented in DynamicPPLMCMCChainsExt +function predict end """ returned(model::Model, parameters::NamedTuple) diff --git a/src/model_utils.jl b/src/model_utils.jl index ac4ec7022..e4c326b39 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -81,7 +81,7 @@ function varname_in_chain!( # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. - for vn in varname_leaves(VarName{sym}(), x) + for vn in AbstractPPL.varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getoptic(vn) varname_in_chain!(x, l ∘ vn_parent, chain, chain_idx, iteration_idx, out) @@ -107,7 +107,7 @@ function values_from_chain( # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. out = similar(x) - for vn in varname_leaves(VarName{sym}(), x) + for vn in AbstractPPL.varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getoptic(vn) out = Accessors.set( diff --git a/src/pobserve_macro.jl b/src/pobserve_macro.jl new file mode 100644 index 000000000..b964e1de4 --- /dev/null +++ b/src/pobserve_macro.jl @@ -0,0 +1,90 @@ +using MacroTools: @capture, @q + +""" + @pobserve + +Perform observations in parallel. +""" +macro pobserve(expr) + return _pobserve(expr) +end + +function _pobserve(expr::Expr) + @capture( + expr, + for ctr_ in iterable_ + block_ + end + ) || error("expected for loop") + # reconstruct the for loop with the processed block + return_expr = @q begin + likelihood_tasks = map($(esc(iterable))) do $(esc(ctr)) + Threads.@spawn begin + $(process_tilde_statements(block)) + end + end + retvals_and_likelihoods = fetch.(likelihood_tasks) + total_likelihoods = sum(last, retvals_and_likelihoods) + if $(DynamicPPL.hasacc)($(esc(:(__varinfo__))), Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = $(DynamicPPL.accloglikelihood!!)( + $(esc(:(__varinfo__))), total_likelihoods + ) + end + map(first, retvals_and_likelihoods) + end + return return_expr +end + +""" + process_tilde_statements(expr) + +This function traverses a block expression `expr` and transforms any +lines in it that look like `lhs ~ rhs` into a simple accumulation of +likelihoods, i.e., `Distributions.logpdf(rhs, lhs)`. +""" +function process_tilde_statements(expr::Expr) + @capture( + expr, + begin + statements__ + end + ) || error("expected block") + @gensym loglike + beginning_expr = :( + $loglike = if $(DynamicPPL.hasacc)($(esc(:(__varinfo__))), Val(:LogLikelihood)) + zero($(DynamicPPL.getloglikelihood)($(esc(:(__varinfo__))))) + else + zero($(DynamicPPL.LogProbType)) + end + ) + n_statements = length(statements) + transformed_statements::Vector{Vector{Expr}} = map(enumerate(statements)) do (i, stmt) + is_last = i == n_statements + if @capture(stmt, lhs_ ~ rhs_) + # TODO: We should probably perform some checks to make sure that this + # indeed was meant to be an observe statement. 
+ @gensym left + e = [ + :($left = $(esc(lhs))), + :($loglike += $(Distributions.logpdf)($(esc(rhs)), $left)), + ] + is_last && push!(e, :(($left, $loglike))) + e + elseif @capture(stmt, lhs_ .~ rhs_) + @gensym val + e = [ + # TODO: dot-tilde + :($val = $(esc(stmt))), + ] + is_last && push!(e, :(($val, $loglike))) + e + else + @gensym val + e = [:($val = $(esc(stmt)))] + is_last && push!(e, :(($val, $loglike))) + e + end + end + new_statements = [beginning_expr, reduce(vcat, transformed_statements)...] + return Expr(:block, new_statements...) +end diff --git a/src/sampler.jl b/src/sampler.jl index 27b990336..98b50ba55 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -41,7 +41,7 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL. provided that supports resuming sampling from a previous state and setting initial parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref) for loading previous states and actually performing the initial sampling step, -respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref) +respectively. Additionally, sometimes one might want to implement an [`init_strategy`](@ref) that specifies how the initial parameter values are sampled if they are not provided. By default, values are sampled from the prior. """ @@ -58,8 +58,9 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) - return vi, nothing + strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform() + _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy) + return new_vi, nothing end """ @@ -67,6 +68,8 @@ end Return a default varinfo object for the given `model` and `sampler`. +The default method for this returns a NTVarInfo (i.e. 'typed varinfo'). + # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. @@ -75,11 +78,26 @@ Return a default varinfo object for the given `model` and `sampler`. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. """ -function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler) +function default_varinfo(rng::Random.AbstractRNG, model::Model, ::AbstractSampler) + # Note that in `AbstractMCMC.step`, the values in the varinfo returned here are + # immediately overwritten by a subsequent call to `init!!`. The reason why we + # _do_ create a varinfo with parameters here (as opposed to simply returning + # an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty + # typed VarInfo would fail. This can happen if two VarNames have different types + # but share the same symbol (e.g. `x.a` and `x.b`). + # TODO(mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments + # and return an empty VarInfo instead. + return typed_varinfo(VarInfo(rng, model)) end +""" + init_strategy(sampler) + +Define the initialisation strategy used for generating initial values when +sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden. 
+""" +init_strategy(::Sampler) = InitFromPrior() + function AbstractMCMC.sample( rng::Random.AbstractRNG, model::Model, @@ -112,24 +130,24 @@ function AbstractMCMC.sample( ) end -# initial step: general interface for resuming and function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... + rng::Random.AbstractRNG, + model::Model, + spl::Sampler; + initial_params::AbstractInitStrategy=init_strategy(spl), + kwargs..., ) - # Sample initial values. + # Generate the default varinfo. Note that any parameters inside this varinfo + # will be immediately overwritten by the next call to `init!!`. vi = default_varinfo(rng, model, spl) - # Update the parameters if provided. - if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi)) - end + # Fill it with initial parameters. Note that, if `InitFromParams` is used, the + # parameters provided must be in unlinked space (when inserted into the + # varinfo, they will be adjusted to match the linking status of the + # varinfo). + _, vi = init!!(rng, model, vi, initial_params) + # Call the actual function that does the first step. return initialstep(rng, model, spl, vi; initial_params, kwargs...) end @@ -147,110 +165,7 @@ loadstate(data) = data Default type of the chain of posterior samples from `sampler`. """ -default_chain_type(sampler::Sampler) = Any - -""" - initialsampler(sampler::Sampler) - -Return the sampler that is used for generating the initial parameters when sampling with -`sampler`. - -By default, it returns an instance of [`SampleFromPrior`](@ref). -""" -initialsampler(spl::Sampler) = SampleFromPrior() - -""" - set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - -Take the values inside `initial_params`, replace the corresponding values in -the given VarInfo object, and return a new VarInfo object with the updated values. - -This differs from `DynamicPPL.unflatten` in two ways: - -1. It works with `NamedTuple` arguments. -2. For the `AbstractVector` method, if any of the elements are missing, it will not -overwrite the original value in the VarInfo (it will just use the original -value instead). -""" -function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - throw( - ArgumentError( - "`initial_params` must be a vector of type `Union{Real,Missing}`. " * - "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.", - ), - ) -end - -function set_initial_values( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} -) - flattened_param_vals = varinfo[:] - length(flattened_param_vals) == length(initial_params) || throw( - DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match " * - "the model size ($(length(flattened_param_vals))).", - ), - ) - - # Update values that are provided. - for i in eachindex(initial_params) - x = initial_params[i] - if x !== missing - flattened_param_vals[i] = x - end - end - - # Update in `varinfo`. 
- new_varinfo = unflatten(varinfo, flattened_param_vals) - return new_varinfo -end - -function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - varinfo = deepcopy(varinfo) - vars_in_varinfo = keys(varinfo) - for v in keys(initial_params) - vn = VarName{v}() - if !(vn in vars_in_varinfo) - for vv in vars_in_varinfo - if subsumes(vn, vv) - throw( - ArgumentError( - "The current model contains sub-variables of $v, such as ($vv). " * - "Using NamedTuple for initial_params is not supported in such a case. " * - "Please use AbstractVector for initial_params instead of NamedTuple.", - ), - ) - end - end - throw(ArgumentError("Variable $v not found in the model.")) - end - end - initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return update_values!!( - varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) - ) -end - -function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) - @debug "Using passed-in initial variable values" initial_params - - # `link` the varinfo if needed. - linked = islinked(vi) - if linked - vi = invlink!!(vi, model) - end - - # Set the values in `vi`. - vi = set_initial_values(vi, initial_params) - - # `invlink` if needed. - if linked - vi = link!!(vi, model) - end - - return vi -end +default_chain_type(::Sampler) = Any """ initialstep(rng, model, sampler, varinfo; kwargs...) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index cfad93ed9..bdf36a750 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -39,7 +39,7 @@ julia> rng = StableRNG(42); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); +julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! 
true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); + _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -232,24 +232,27 @@ end # Constructor from `Model`. function SimpleVarInfo{T}( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) where {T<:Real} - new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return last(evaluate!!(new_model, SimpleVarInfo{T}())) + return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) end function SimpleVarInfo{T}( - model::Model, sampler::AbstractSampler=SampleFromPrior() + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() ) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, sampler) + return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) end # Constructors without type param function SimpleVarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return SimpleVarInfo{LogProbType}(rng, model, sampler) + return SimpleVarInfo{LogProbType}(rng, model, init_strategy) end -function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) +function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) end # Constructor from `VarInfo`. @@ -265,12 +268,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) @@ -408,12 +411,8 @@ function BangBang.push!!( return Accessors.@set vi.values = setindex!!(vi.values, value, vn) end -const SimpleOrThreadSafeSimple{T,V,C} = Union{ - SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} -} - # Necessary for `matchingvalue` to work properly. -Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V +Base.eltype(::SimpleVarInfo{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) @@ -471,7 +470,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::SimpleOrThreadSafeSimple, + vi::SimpleVarInfo, ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. @@ -482,21 +481,25 @@ function assume( return value, vi end -# NOTE: We don't implement `settrans!!(vi, trans, vn)`. 
-function settrans!!(vi::SimpleVarInfo, trans) +function settrans!!(vi::SimpleVarInfo, trans::Bool) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) return Accessors.@set vi.transformation = transformation end -function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) +function settrans!!(vi::SimpleVarInfo, trans::Bool, ::VarName) + # We keep this method around just to obey the AbstractVarInfo interface. + # However, note that this would only be a valid operation if it would be a + # no-op, which we check here. + if trans != istrans(vi) + error( + "Individual variables in SimpleVarInfo cannot have different `settrans` statuses.", + ) + end end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) islinked(vi::SimpleVarInfo) = istrans(vi) diff --git a/src/test_utils.jl b/src/test_utils.jl index 65079f023..195345d60 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,7 +11,7 @@ using Bijectors: Bijectors using Accessors: Accessors # For backwards compat. -using DynamicPPL: varname_leaves, update_values!! +using DynamicPPL: update_values!! include("test_utils/model_interface.jl") include("test_utils/models.jl") diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 863db4262..d53ba6c5f 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -25,25 +25,49 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. """ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - # `NodeTrait`. node_trait = DynamicPPL.NodeTrait(context) - # Throw error immediately if it it's missing a `NodeTrait` implementation. - node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || - throw(ValueError("Invalid NodeTrait: $node_trait")) - - # To see change, let's make sure we're using a different leaf context than the current. - leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + if node_trait isa DynamicPPL.IsLeaf + test_leaf_context(context, model) + elseif node_trait isa DynamicPPL.IsParent + test_parent_context(context, model) else - DefaultContext() + error("Invalid NodeTrait: $node_trait") end - @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == - leafcontext_new +end + +function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # The interface methods. - if node_trait isa DynamicPPL.IsParent - # `childcontext` and `setchildcontext` - # With new child context + # Note that for a leaf context we can't assume that it will work with an + # empty VarInfo. (For example, DefaultContext will error with empty + # varinfos.) Thus we only test evaluation with VarInfos that are already + # filled with values. 
+ @testset "evaluation" begin + # Generate a new filled untyped varinfo + _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + # Set the test context as the new leaf context + new_model = contextualize(model, DynamicPPL.setleafcontext(model.context, context)) + # Check that evaluation works + for vi in [untyped_vi, typed_vi] + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end +end + +function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent + + @testset "get/set leaf and child contexts" begin + # Ensure we're using a different leaf context than the current. + leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + DynamicPPL.DynamicTransformationContext{false}() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new childcontext_new = TestParentContext() @test DynamicPPL.childcontext( DynamicPPL.setchildcontext(context, childcontext_new) @@ -56,19 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod leafcontext_new end - # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). - # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. - # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the - # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. - # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. - # Untyped varinfo. - varinfo_untyped = DynamicPPL.VarInfo() - model_with_spl = contextualize(model, SamplingContext(context)) - model_without_spl = contextualize(model, context) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any - # Typed varinfo. - varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any + @testset "initialisation and evaluation" begin + new_model = contextualize(model, context) + for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())] + # Initialisation + _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + @test vi isa DynamicPPL.VarInfo + # Evaluation + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 93aed074c..cb949464e 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,9 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. 
""" function varnames(model::Model) - return collect( - keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) - ) + return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) end """ diff --git a/src/test_utils/sampler.jl b/src/test_utils/sampler.jl index 71cdb1cac..3ef965bad 100644 --- a/src/test_utils/sampler.jl +++ b/src/test_utils/sampler.jl @@ -51,7 +51,7 @@ function test_sampler( for vn in filter(varnames_filter, varnames(model)) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. - for vn_leaf in varname_leaves(vn, get(target_values, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(target_values, vn)) target_value = get(target_values, vn_leaf) chain_mean_value = marginal_mean_of_samples(chain, vn_leaf) @test chain_mean_value ≈ target_value atol = atol rtol = rtol diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 26e2aa7ca..e3026ba6c 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -15,17 +15,13 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal end """ - setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false) + setup_varinfos(model::Model, example_values::NamedTuple, varnames) Return a tuple of instances for different implementations of `AbstractVarInfo` with each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`. -If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions -of the varinfo instances. """ -function setup_varinfos( - model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false -) +function setup_varinfos(model::Model, example_values::NamedTuple, varnames) # VarInfo vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) @@ -51,9 +47,5 @@ function setup_varinfos( last(DynamicPPL.evaluate!!(model, vi)) end - if include_threadsafe - varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) - end - return varinfos end diff --git a/src/threadsafe.jl b/src/threadsafe.jl deleted file mode 100644 index 6ca3b9852..000000000 --- a/src/threadsafe.jl +++ /dev/null @@ -1,236 +0,0 @@ -""" - ThreadSafeVarInfo - -A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an -array of accumulators for thread-safe execution of a probabilistic model. -""" -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo - varinfo::V - accs_by_thread::Vector{L} -end -function ThreadSafeVarInfo(vi::AbstractVarInfo) - # In ThreadSafeVarInfo we use threadid() to index into the array of logp - # fields. This is not good practice --- see - # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full - # explanation --- but it has worked okay so far. - # The use of nthreads()*2 here ensures that threadid() doesn't exceed - # the length of the logps array. Ideally, we would use maxthreadid(), - # but Mooncake can't differentiate through that. Empirically, nthreads()*2 - # seems to provide an upper bound to maxthreadid(), so we use that here. 
- # See https://github.com/TuringLang/DynamicPPL.jl/pull/936 - accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)] - return ThreadSafeVarInfo(vi, accs_by_thread) -end -ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi - -transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) - -# Set the accumulator in question in vi.varinfo, and set the thread-specific -# accumulators of the same type to be empty. -function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator) - inner_vi = setacc!!(vi.varinfo, acc) - news_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread) - return ThreadSafeVarInfo(inner_vi, news_accs_by_thread) -end - -# Get both the main accumulator and the thread-specific accumulators of the same type and -# combine them. -function getacc(vi::ThreadSafeVarInfo, accname::Val) - main_acc = getacc(vi.varinfo, accname) - other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread) - return foldl(combine, other_accs; init=main_acc) -end - -hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) -acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) - -function getaccs(vi::ThreadSafeVarInfo) - # This method is a bit finicky to maintain type stability. For instance, moving the - # accname -> Val(accname) part in the main `map` call makes constant propagation fail - # and this becomes unstable. Do check the effects if you make edits. - accnames = acckeys(vi) - accname_vals = map(Val, accnames) - return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals)) -end - -# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that -# should _not_ be thread-specific a specific method has to be written. -function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val) - tid = Threads.threadid() - vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname) - return vi -end - -function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) - tid = Threads.threadid() - vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid]) - return vi -end - -has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) - -function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) -end - -syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) - -setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) - -keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) -haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) - -islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo) - -function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) -end - -function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) -end - -function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) -end - -function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) -end - -# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. 
-# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates -# to define `getacc(vi)`. -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{false}()) - ) - return settrans!!(last(evaluate!!(model, vi)), t) -end - -function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{true}()) - ) - return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) -end - -function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end - -function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) -end - -# These two StaticTransformation methods needed to resolve ambiguities -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model) -end - -function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) - # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the - # `getacc(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in - # the `getlogprior(vi)`. - return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) -end - -# `getindex` -getindex(vi::ThreadSafeVarInfo, ::Colon) = getindex(vi.varinfo, Colon()) -getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) -getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = getindex(vi.varinfo, vns) -function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) - return getindex(vi.varinfo, vn, dist) -end -function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution) - return getindex(vi.varinfo, vns, dist) -end - -function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:VarName}) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) -end - -vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo) -vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn) -function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) - return vector_getranges(vi.varinfo, vns) -end - -isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) -function BangBang.empty!!(vi::ThreadSafeVarInfo) - return resetaccs!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) -end - -function resetaccs!!(vi::ThreadSafeVarInfo) - vi = Accessors.@set vi.varinfo = resetaccs!!(vi.varinfo) - for i in eachindex(vi.accs_by_thread) - vi.accs_by_thread[i] = map(reset, vi.accs_by_thread[i]) - end - return vi -end - -values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) -values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) - -function unset_flag!( - vi::ThreadSafeVarInfo, 
vn::VarName, flag::String, ignoreable::Bool=false -) - return unset_flag!(vi.varinfo, vn, flag, ignoreable) -end -function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return is_flagged(vi.varinfo, vn, flag) -end - -function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) -end - -istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) - -getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) - -function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) - return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) -end - -function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) - return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns) -end - -function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo) - return Accessors.@set varinfo_left.varinfo = merge( - varinfo_left.varinfo, varinfo_right.varinfo - ) -end - -function invlink_with_logpdf(vi::ThreadSafeVarInfo, vn::VarName, dist, y) - return invlink_with_logpdf(vi.varinfo, vn, dist, y) -end - -function from_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName) - return from_internal_transform(varinfo.varinfo, vn) -end -function from_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName, dist) - return from_internal_transform(varinfo.varinfo, vn, dist) -end - -function from_linked_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName) - return from_linked_internal_transform(varinfo.varinfo, vn) -end -function from_linked_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName, dist) - return from_linked_internal_transform(varinfo.varinfo, vn, dist) -end diff --git a/src/utils.jl b/src/utils.jl index d3371271f..c7d1e089f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -837,245 +837,6 @@ end # Handle `AbstractDict` differently since `eltype` results in a `Pair`. infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET) -""" - varname_leaves(vn::VarName, val) - -Return an iterator over all varnames that are represented by `vn` on `val`. - -# Examples -```jldoctest -julia> using DynamicPPL: varname_leaves - -julia> foreach(println, varname_leaves(@varname(x), rand(2))) -x[1] -x[2] - -julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) -x[1:2][1] -x[1:2][2] - -julia> x = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_leaves(@varname(x), x)) -x.y -x.z[1][1] -x.z[2][1] -``` -""" -varname_leaves(vn::VarName, ::Real) = [vn] -function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for - I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_leaves( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] - ) for I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() - varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)) - end - return Iterators.flatten(iter) -end - -""" - varname_and_value_leaves(vn::VarName, val) - -Return an iterator over all varname-value pairs that are represented by `vn` on `val`. 
- -# Examples -```jldoctest varname-and-value-leaves -julia> using DynamicPPL: varname_and_value_leaves - -julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2)) -(x[1], 1) -(x[2], 2) - -julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2)) -(x[1:2][1], 1) -(x[1:2][2], 2) - -julia> x = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(@varname(x), x)) -(x.y, 1) -(x.z[1][1], 2.0) -(x.z[2][1], 3.0) -``` - -There is also some special handling for certain types: - -```jldoctest varname-and-value-leaves -julia> using LinearAlgebra - -julia> x = reshape(1:4, 2, 2); - -julia> # `LowerTriangular` - foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) -(x[1, 1], 1) -(x[2, 1], 2) -(x[2, 2], 4) - -julia> # `UpperTriangular` - foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) -(x[1, 1], 1) -(x[1, 2], 3) -(x[2, 2], 4) - -julia> # `Cholesky` with lower-triangular - foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) -(x.L[1, 1], 1.0) -(x.L[2, 1], 0.0) -(x.L[2, 2], 1.0) - -julia> # `Cholesky` with upper-triangular - foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) -(x.U[1, 1], 1.0) -(x.U[1, 2], 0.0) -(x.U[2, 2], 1.0) -``` -""" -function varname_and_value_leaves(vn::VarName, x) - return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) -end - -""" - varname_and_value_leaves(container) - -Return an iterator over all varname-value pairs that are represented by `container`. - -This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container -containing multiple varnames. - -See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref). - -# Examples -```jldoctest varname-and-value-leaves-container -julia> using DynamicPPL: varname_and_value_leaves - -julia> # With an `OrderedDict` - dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(dict)) -(y, 1) -(z[1][1], 2.0) -(z[2][1], 3.0) - -julia> # With a `NamedTuple` - nt = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(nt)) -(y, 1) -(z[1][1], 2.0) -(z[2][1], 3.0) -``` -""" -function varname_and_value_leaves(container::OrderedDict) - return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container) -end -function varname_and_value_leaves(container::NamedTuple) - return Iterators.flatten( - varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container) - ) -end - -""" - Leaf{T} - -A container that represents the leaf of a nested structure, implementing -`iterate` to return itself. - -This is particularly useful in conjunction with `Iterators.flatten` to -prevent flattening of nested structures. -""" -struct Leaf{T} - value::T -end - -Leaf(xs...) = Leaf(xs) - -# Allow us to treat `Leaf` as an iterator containing a single element. -# Something like an `[x]` would also be an iterator with a single element, -# but when we call `flatten` on this, it would also iterate over `x`, -# unflattening that too. By making `Leaf` a single-element iterator, which -# returns itself, we can call `iterate` on this as many times as we like -# without causing any change. The result is that `Iterators.flatten` -# will _not_ unflatten `Leaf`s. 
-# Note that this is similar to how `Base.iterate` is implemented for `Real`:: -# -# julia> iterate(1) -# (1, nothing) -# -# One immediate example where this becomes in our scenario is that we might -# have `missing` values in our data, which does _not_ have an `iterate` -# implemented. Calling `Iterators.flatten` on this would cause an error. -Base.iterate(leaf::Leaf) = leaf, nothing -Base.iterate(::Leaf, _) = nothing - -# Convenience. -value(leaf::Leaf) = leaf.value - -# Leaf-types. -varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)] -function varname_and_value_leaves_inner( - vn::VarName, val::AbstractArray{<:Union{Real,Missing}} -) - return ( - Leaf( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), - val[I], - ) for I in CartesianIndices(val) - ) -end -# Containers. -function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_and_value_leaves_inner( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), - val[I], - ) for I in CartesianIndices(val) - ) -end -function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() - varname_and_value_leaves_inner( - VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) - ) - end - - return Iterators.flatten(iter) -end -# Special types. -function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) - # TODO: Or do we use `PDMat` here? - return if x.uplo == 'L' - varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) - else - varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) - end -end -function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) - return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) - # Iteration over the lower-triangular indices. - for I in CartesianIndices(x) if I[1] >= I[2] - ) -end -function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) - return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) - # Iteration over the upper-triangular indices. - for I in CartesianIndices(x) if I[1] <= I[2] - ) -end - broadcast_safe(x) = x broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) diff --git a/src/varinfo.jl b/src/varinfo.jl index dec4db3ec..6311be9e0 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -113,10 +113,14 @@ function VarInfo(meta=Metadata()) end """ - VarInfo([rng, ]model[, sampler]) + VarInfo( + [rng::Random.AbstractRNG], + model, + [init_strategy::AbstractInitStrategy] + ) -Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`. +Generate a `VarInfo` object for the given `model`, by initialising it with the +given `rng` and `init_strategy`. !!! warning @@ -129,12 +133,14 @@ the given `rng`, `sampler`. instead. 
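For illustration, a minimal sketch of how the updated constructors can be driven by an initialisation strategy; this assumes the `InitFromPrior`/`InitFromUniform` strategies introduced in this release and uses a toy model purely for demonstration:

```julia
using DynamicPPL, Distributions, Random

# Toy model used only for this sketch.
@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end
model = demo()

vi_prior   = VarInfo(model)                        # defaults to InitFromPrior()
vi_uniform = VarInfo(Random.default_rng(), model, InitFromUniform())

# The lower-level constructors accept the same optional strategy argument:
vi_typed   = DynamicPPL.typed_varinfo(model, InitFromPrior())
vi_untyped = DynamicPPL.untyped_varinfo(model, InitFromUniform())
```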
""" function VarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_varinfo(rng, model, sampler) + return typed_varinfo(rng, model, init_strategy) end -function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return VarInfo(Random.default_rng(), model, sampler) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return VarInfo(Random.default_rng(), model, init_strategy) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -146,9 +152,6 @@ const UntypedVarInfo = VarInfo{<:Metadata} # something which carried both its keys as well as its values' types as type # parameters. const NTVarInfo = VarInfo{<:NamedTuple} -const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ - VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} -} function Base.:(==)(vi1::VarInfo, vi2::VarInfo) return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) @@ -195,7 +198,7 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler]) + untyped_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -203,15 +206,17 @@ Construct a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) + return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_varinfo(Random.default_rng(), model, sampler) +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -270,7 +275,7 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler]) + typed_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. @@ -278,19 +283,21 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. 
""" function typed_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_varinfo(untyped_varinfo(rng, model, sampler)) + return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_varinfo(Random.default_rng(), model, sampler) +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return typed_varinfo(Random.default_rng(), model, init_strategy) end """ - untyped_vector_varinfo([rng, ]model[, sampler]) + untyped_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has just a single `VarNamedVector` as its metadata field. @@ -298,23 +305,27 @@ Return a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) + return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_vector_varinfo(Random.default_rng(), model, sampler) +function untyped_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_vector_varinfo([rng, ]model[, sampler]) + typed_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -322,7 +333,7 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. 
""" function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -334,12 +345,16 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy)) end -function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_vector_varinfo(Random.default_rng(), model, sampler) +function typed_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -364,6 +379,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) # The below line is finicky for type stability. For instance, assigning the eltype to # convert to into an intermediate variable makes this unstable (constant propagation) # fails. Take care when editing. + # TODO(penelopeysm): Can this be simplified if TSVI is gone? accs = map( acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) ) @@ -944,12 +960,6 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model) return vi end -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) -end - function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -957,17 +967,6 @@ function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model:: return vi end -function link!!( - t::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) -end - function _link!!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) @@ -1049,12 +1048,6 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) return vi end -function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) -end - function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. 
has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1062,17 +1055,6 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, mode return vi end -function invlink!!( - ::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) -end - function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do @@ -1162,27 +1144,10 @@ function link(::DynamicTransformation, varinfo::VarInfo, model::Model) return _link(model, varinfo, keys(varinfo)) end -function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) -end - function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _link(model, varinfo, vns) end -function link( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) -end - function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) @@ -1326,29 +1291,10 @@ function invlink(::DynamicTransformation, vi::VarInfo, model::Model) return _invlink(model, vi, keys(vi)) end -function invlink( - ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) -end - function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _invlink(model, varinfo, vns) end -function invlink( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. 
- return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) -end - function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) @@ -1508,42 +1454,6 @@ function islinked(vi::VarInfo) return any(istrans(vi, vn) for vn in keys(vi)) end -function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) - return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) -end -function nested_setindex_maybe!( - vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym} -) where {names,sym} - return if sym in names - _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) - else - nothing - end -end -function _nested_setindex_maybe!( - vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName -) - # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = Base.keys(md) - if vn in vns - setindex!(vi, val, vn) - return vn - end - - # Otherwise, we need to check if either of the `vns` subsumes `vn`. - i = findfirst(Base.Fix2(subsumes, vn), vns) - i === nothing && return nothing - - vn_parent = vns[i] - val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail optic. - optic = remove_parent_optic(vn_parent, vn) - # Update the value for the parent. - val_parent_updated = set!!(val_parent, optic, val) - setindex!(vi, val_parent_updated, vn_parent) - return vn_parent -end - # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type function getindex(vi::VarInfo, vn::VarName) @@ -1832,7 +1742,7 @@ end Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. """ -function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) +function _apply!(kernel!, vi::VarInfo, values, keys) keys_strings = map(string, collect_maybe(keys)) num_indices_seen = 0 @@ -1890,7 +1800,7 @@ end end end -function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) +function _find_missing_keys(vi::VarInfo, keys) string_vns = map(string, collect_maybe(Base.keys(vi))) # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. missing_keys = filter(keys) do key @@ -1955,119 +1865,12 @@ function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - settrans!!(vi, false, vn) - end - - return indices -end - -""" - setval_and_resample!(vi::VarInfo, x) - setval_and_resample!(vi::VarInfo, values, keys) - setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call -`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means -that the next time we call `model(vi)` these variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. 
- -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) - return setval_and_resample!(vi, values(x), keys(x)) -end -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) - return _apply!(_setval_and_resample_kernel!, vi, values, keys) -end -function setval_and_resample!( - vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - if supports_varname_indexing(chains) - # First we need to set every variable to be resampled. - for vn in keys(vi) - set_flag!(vi, vn, "del") - end - # Then we set the variables in `varinfo` from `chain`. - for vn in varnames(chains) - vn_updated = nested_setindex_maybe!( - vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn - ) - - # Unset the `del` flag if we found something. - if vn_updated !== nothing - # NOTE: This will be triggered even if only a subset of a variable has been set! - unset_flag!(vi, vn_updated, "del") - end - end - else - setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) - end -end - -function _setval_and_resample_kernel!( - vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys -) +function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) settrans!!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") end return indices diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index d756a4922..2336b89b6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,6 +766,11 @@ function update_internal!( return nothing end +function BangBang.push!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! 
check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new diff --git a/test/ad.jl b/test/ad.jl index 371e79b06..0e5d8d7cf 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -111,9 +111,10 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) + vi = DynamicPPL.link!!(VarInfo(model), model) sampling_model = contextualize(model, SamplingContext(model.context)) ldf = LogDensityFunction( - sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) + sampling_model, getlogjoint_internal, vi; adtype=AutoReverseDiff(; compile=true) ) x = ldf.varinfo[:] @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..3c451b6b0 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,11 +193,11 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - # During the model evaluation, its context is wrapped in a - # SamplingContext, so `model_` is not going to be equal to `model`. - # We can still check equality of `f` though. + # During the model evaluation, its leaf context is changed to an InitContext, so + # `model_` is not going to be equal to `model`. We can still check equality of `f` + # though. @test model_.f === model.f - @test model_.context isa SamplingContext + @test model_.context isa DynamicPPL.InitContext @test model_.context.rng isa Random.AbstractRNG # disable warnings @@ -598,20 +598,15 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) + retval_and_vi = DynamicPPL.init!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi - else - @test retval == svi - end + @test retval == svi # We should not be altering return-values other than at top-level. 
@model function demo() @@ -620,11 +615,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..1a6279bf4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -20,8 +20,9 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested, collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue + prefix_cond_and_fixed_variables +using LinearAlgebra: I +using Random: Xoshiro using EnzymeCore @@ -49,7 +50,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() contexts = Dict( :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - :sampling => SamplingContext(), :prefix => PrefixContext(@varname(x)), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( @@ -92,7 +92,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # here to split up arrays which could potentially have some, # but not all, elements being `missing`. conditioned_vns = mapreduce( - p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second), + p -> AbstractPPL.varname_leaves(p.first, p.second), vcat, pairs(conditioned_values), ) @@ -103,7 +103,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. @test hasconditioned_nested(context, vn) @@ -165,29 +165,30 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test new_vn == @varname(a.x[1]) @test new_ctx == DefaultContext() - ctx2 = SamplingContext(PrefixContext(@varname(a))) + ctx2 = FixedContext((b=4,), PrefixContext(@varname(a))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext() + @test new_ctx == FixedContext((b=4,)) ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == ConditionContext((a=1,)) - ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + ctx4 = FixedContext( + (b=4,), PrefixContext(@varname(a), ConditionContext((a=1,))) + ) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext(ConditionContext((a=1,))) + @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) - context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) - sampling_model = contextualize(model, context) - # Sample with the context. 
- varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(sampling_model, varinfo) + context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) + new_model = contextualize(model, context) + # Initialize a new varinfo with the prefixed model + _, varinfo = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) @@ -431,4 +432,246 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test fixed(c6) == Dict(@varname(a.b.d) => 2) end end + + @testset "InitContext" begin + empty_varinfos = [ + ("untyped+metadata", VarInfo()), + ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), + ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), + ( + "typed+VNV", + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + ), + ("SVI+NamedTuple", SimpleVarInfo()), + ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing + end + + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. 
+ model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end + + function test_link_status_respected(strategy::AbstractInitStrategy) + @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin + @model logn() = a ~ LogNormal() + model = logn() + vi = VarInfo(model) + linked_vi = DynamicPPL.link!!(vi, model) + _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) + @test DynamicPPL.istrans(new_vi) + # this is the unlinked value, since it uses `getindex` + a = new_vi[@varname(a)] + # internal logjoint should correspond to the transformed value + @test isapprox( + DynamicPPL.getlogjoint_internal(new_vi), logpdf(Normal(), log(a)) + ) + # user logjoint should correspond to the transformed value + @test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(LogNormal(), a)) + @test isapprox( + only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a) + ) + end + end + + @testset "InitFromPrior" begin + test_generating_new_values(InitFromPrior()) + test_replacing_values(InitFromPrior()) + test_rng_respected(InitFromPrior()) + test_link_status_respected(InitFromPrior()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. 
+ @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), InitFromPrior()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end + + @testset "InitFromUniform" begin + test_generating_new_values(InitFromUniform()) + test_replacing_values(InitFromUniform()) + test_rng_respected(InitFromUniform()) + test_link_status_respected(InitFromUniform()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), InitFromUniform(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), InitFromUniform(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end + + @testset "InitFromParams" begin + test_link_status_respected(InitFromParams((; a=1.0))) + test_link_status_respected(InitFromParams(Dict(@varname(a) => 1.0))) + + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end + + @testset "given only partial parameters" begin + my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + @testset "with InitFromPrior fallback" begin + _, vi = DynamicPPL.init!!( + Xoshiro(468), + model, + deepcopy(empty_vi), + InitFromParams(params_nt, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), + model, + deepcopy(empty_vi), + InitFromParams(params_dict, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + + @testset "with no fallback" begin + # These just don't have an entry for `y`. + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) + ) + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) + ) + # We also explicitly test the case where `y = missing`. 
+ params_nt_missing = (; x=my_x, y=missing) + params_dict_missing = Dict( + @varname(x) => my_x, @varname(y) => missing + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_nt_missing, nothing), + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_dict_missing, nothing), + ) + end + end + end + end + end end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 6737cf056..692f53911 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -30,7 +30,7 @@ DynamicPPL.UntypedVarInfo # Evaluation works (and it would even do so in practice), but sampling - # fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. + # will fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. @model function demo4() x ~ Bernoulli() if x @@ -62,33 +62,37 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation and sampling - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f_eval, argtypes_eval) - - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, varinfo - ) - JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed - # If the test failed, check why it didn't infer a typed varinfo + # If the test failed, check what the type stability problem was for + # the typed varinfo. This is mostly useful for debugging from test + # logs. if !is_typed + @info "Model `$(model.f)` is not type stable with typed varinfo." 
typed_vi = DynamicPPL.typed_varinfo(model) - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi + + @info "Evaluating with DefaultContext:" + model = DynamicPPL.contextualize( + model, + DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()), + ) + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo + ) + JET.test_call(f, argtypes) + + @info "Initialising with InitContext:" + model = DynamicPPL.contextualize( + model, + DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()), ) - JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, typed_vi + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo ) - JET.test_call(f_sample, argtypes_sample) + JET.test_call(f, argtypes) end end end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 3ba5edfe1..79e13ad84 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -2,7 +2,12 @@ @model demo() = x ~ Normal() model = demo() - chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) + chain = MCMCChains.Chains( + randn(1000, 2, 1), + [:x, :y], + Dict(:internals => [:y]); + info=(; varname_to_symbol=Dict(@varname(x) => :x)), + ) chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 diff --git a/test/model.jl b/test/model.jl index 81f84e548..ffc4c23fe 100644 --- a/test/model.jl +++ b/test/model.jl @@ -71,7 +71,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() chain_sym_map = Dict{Symbol,Symbol}() for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent]) + vn_children = AbstractPPL.varname_leaves(vn_parent, var_info[vn_parent]) for vn_child in vn_children chain_sym_map[Symbol(vn_child)] = sym end @@ -142,37 +142,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end - @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin - @model function multiple_types(x) - ns ~ filldist(Normal(0, 2.0), 3) - m ~ Uniform(0, 1) - return x ~ Normal(m, 1) - end - model = multiple_types(1) - chain = make_chain_from_prior(model, 10) - loglikelihood(model, chain) - logprior(model, chain) - logjoint(model, chain) - end - - @testset "rng" begin - model = GDEMO_DEFAULT - - for sampler in (SampleFromPrior(), SampleFromUniform()) - for i in 1:10 - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - vals = vi[:] - - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - @test vi[:] == vals - end - end - end - @testset "defaults without VarInfo, Sampler, and Context" begin model = GDEMO_DEFAULT @@ -332,7 +301,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) + vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. 
@test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -347,7 +316,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Extract varnames and values. vns_and_vals_xs = map( - collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs + collect ∘ Base.Fix1(AbstractPPL.varname_and_value_leaves, @varname(x)), xs ) vns = map(first, first(vns_and_vals_xs)) vals = map(vns_and_vals_xs) do vns_and_vals @@ -513,7 +482,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct a chain with 'sampled values' of β ground_truth_β = 2 - β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) + β_chain = MCMCChains.Chains( + rand(Normal(ground_truth_β, 0.002), 1000), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), + ) # Generate predictions from that chain xs_test = [10 + 0.1, 10 + 2 * 0.1] @@ -559,7 +532,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "prediction from multiple chains" begin # Normal linreg model multiple_β_chain = MCMCChains.Chains( - reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] + reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), ) predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) @test size(multiple_β_chain, 3) == size(predictions, 3) @@ -584,43 +559,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end - - @testset "with AbstractVector{<:AbstractVarInfo}" begin - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(1, 1) - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - end - - ground_truth_β = 2.0 - # the data will be ignored, as we are generating samples from the prior - xs_train = 1:0.1:10 - ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) - m_lin_reg = linear_reg(xs_train, ys_train) - chain = [ - last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for - _ in 1:10000 - ] - - # chain is generated from the prior - @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 - - xs_test = [10 + 0.1, 10 + 2 * 0.1] - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) - predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain) - - @test size(predicted_vis) == size(chain) - @test Set(keys(predicted_vis[1])) == - Set([@varname(β), @varname(y[1]), @varname(y[2])]) - # because β samples are from the prior, the std will be larger - @test mean([ - predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[1] rtol = 0.1 - @test mean([ - predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[2] rtol = 0.1 - end end @testset "ProductNamedTupleDistribution sampling" begin diff --git a/test/model_utils.jl b/test/model_utils.jl index 720ae55aa..af695dbf2 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -6,11 +6,11 @@ chain = make_chain_from_prior(model, 10) for (i, d) in enumerate(value_iterator_from_chain(model, chain)) for vn in keys(d) - val = DynamicPPL.getvalue(d, vn) + val = AbstractPPL.getvalue(d, vn) # Because value_iterator_from_chain groups varnames with # the same parent symbol, we have to ungroup them here - for vn_leaf in DynamicPPL.varname_leaves(vn, val) - val_leaf = DynamicPPL.getvalue(d, vn_leaf) + for vn_leaf in AbstractPPL.varname_leaves(vn, val) + val_leaf = AbstractPPL.getvalue(d, vn_leaf) @test val_leaf == 
chain[i, Symbol(vn_leaf), 1] end end diff --git a/test/pobserve_macro.jl b/test/pobserve_macro.jl new file mode 100644 index 000000000..ff523c3b5 --- /dev/null +++ b/test/pobserve_macro.jl @@ -0,0 +1,74 @@ +module DynamicPPLPobserveMacroTests + +using DynamicPPL, Distributions, Test + +@testset verbose = true "pobserve_macro.jl" begin + @testset "loglikelihood is correctly accumulated" begin + @model function f(x) + @pobserve for i in eachindex(x) + x[i] ~ Normal() + end + end + x = randn(3) + expected_loglike = loglikelihood(Normal(), x) + vi = VarInfo(f(x)) + @test isapprox(DynamicPPL.getloglikelihood(vi), expected_loglike) + end + + @testset "doesn't error when varinfo has no likelihood acc" begin + @model function f(x) + @pobserve for i in eachindex(x) + x[i] ~ Normal() + end + end + x = randn(3) + vi = VarInfo() + vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),)) + @test DynamicPPL.evaluate!!(f(x), vi) isa Any + end + + @testset "return values are correct" begin + @testset "single expression at the end" begin + @model function f(x) + xplusone = @pobserve for i in eachindex(x) + x[i] ~ Normal() + x[i] + 1 + end + return xplusone + end + x = randn(3) + @test f(x)() == x .+ 1 + + @testset "calculations are not repeated" begin + # Make sure that the final expression inside pobserve is not evaluated + # multiple times. + counter = 0 + increment_and_return(y) = (counter += 1; y) + @model function g(x) + xs = @pobserve for i in eachindex(x) + x[i] ~ Normal() + increment_and_return(x[i]) + end + return xs + end + x = randn(3) + @test g(x)() == x + @test counter == length(x) + end + end + + @testset "tilde expression at the end" begin + @model function f(x) + xs = @pobserve for i in eachindex(x) + # This should behave as if it returns `x[i]` + x[i] ~ Normal() + end + return xs + end + x = randn(3) + @test f(x)() == x + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index c60c06786..ef417350f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,7 @@ include("test_util.jl") include("utils.jl") include("accumulators.jl") include("compiler.jl") + include("pobserve_macro.jl") include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") @@ -70,7 +71,6 @@ include("test_util.jl") include("lkj.jl") include("contexts.jl") include("context_implementations.jl") - include("threadsafe.jl") include("debug_utils.jl") include("submodels.jl") include("bijector.jl") diff --git a/test/sampler.jl b/test/sampler.jl index 5eb0da057..c812de938 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -1,4 +1,17 @@ @testset "sampler.jl" begin + @testset "varnames with same symbol but different type" begin + struct S <: AbstractMCMC.AbstractSampler end + DynamicPPL.initialstep(rng, model, ::DynamicPPL.Sampler{S}, vi; kwargs...) = vi + @model function g() + y = (; a=1, b=2) + y.a ~ Normal() + return y.b ~ Normal() + end + model = g() + spl = DynamicPPL.Sampler(S()) + @test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any + end + @testset "initial_state and resume_from kwargs" begin # Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our # overloaded method. @@ -126,8 +139,8 @@ @test length(chains) == N # `m` is Gaussian, i.e. no transformation is used, so it - # should have a mean equal to its prior, i.e. 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 + # will be drawn from U[-2, 2] and its mean should be 0. 
+ @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 @@ -170,8 +183,8 @@ end # initial samplers - DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() - @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() + DynamicPPL.init_strategy(::Sampler{OnlyInitAlgUniform}) = InitFromUniform() + @test DynamicPPL.init_strategy(Sampler(OnlyInitAlgDefault())) == InitFromPrior() for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) # model with one variable: initialization p = 0.2 @@ -182,7 +195,7 @@ model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - let inits = (; p=0.2) + let inits = InitFromParams((; p=0.2)) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] @test getlogjoint(chain[1]) == lptrue @@ -210,7 +223,7 @@ end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - for inits in ([4, -1], (; s=4, m=-1)) + let inits = InitFromParams((; s=4, m=-1)) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @@ -234,7 +247,7 @@ end # set only m = -1 - for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1)) + for inits in (InitFromParams((; s=missing, m=-1)), InitFromParams((; m=-1))) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] @@ -254,54 +267,6 @@ @test c[1].metadata.m.vals == [-1] end end - - # specify `initial_params=nothing` - Random.seed!(1234) - chain1 = sample(model, sampler, 1; progress=false) - Random.seed!(1234) - chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) - @test_throws DimensionMismatch sample( - model, sampler, 1; progress=false, initial_params=zeros(10) - ) - @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals - @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals - - # parallel sampling - Random.seed!(1234) - chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) - Random.seed!(1234) - chains2 = sample( - model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false - ) - for (c1, c2) in zip(chains1, chains2) - @test c1[1].metadata.m.vals == c2[1].metadata.m.vals - @test c1[1].metadata.s.vals == c2[1].metadata.s.vals - end - end - - @testset "error handling" begin - # https://github.com/TuringLang/Turing.jl/issues/2452 - @model function constrained_uniform(n) - Z ~ Uniform(10, 20) - X = Vector{Float64}(undef, n) - for i in 1:n - X[i] ~ Uniform(0, Z) - end - end - - n = 2 - initial_z = 15 - initial_x = [0.2, 0.5] - model = constrained_uniform(n) - vi = VarInfo(model) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], model - ) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), model - ) end end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 526fce92c..01cbfc593 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -160,7 +160,7 @@ ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) + _, svi_new = DynamicPPL.init!!(model, svi) # Realization for `m` should be different wp. 1. 
for vn in DynamicPPL.TestUtils.varnames(model) @@ -228,9 +228,9 @@ # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) + svi_nt = last(DynamicPPL.init!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) + svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -275,7 +275,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) + vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. diff --git a/test/test_util.jl b/test/test_util.jl index e04486760..ab2a80dc0 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -13,9 +13,6 @@ const gdemo_default = gdemo_d() Return string representing a short description of `vi`. """ -function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) - return "threadsafe($(short_varinfo_name(vi.varinfo)))" -end function short_varinfo_name(vi::DynamicPPL.NTVarInfo) return if DynamicPPL.has_varnamedvector(vi) "TypedVectorVarInfo" @@ -72,7 +69,7 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I # We have to use varname_and_value_leaves so that each parameter is a scalar dicts = map(varinfos) do t vals = DynamicPPL.values_as(t, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) tuples = mapreduce(collect, vcat, iters) # The following loop is a replacement for: # push!(varnames, map(first, tuples)...) @@ -87,8 +84,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I varnames = collect(varnames) # Construct matrix of values vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct dict of varnames -> symbol + vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) # Construct and return the Chains object - return Chains(vals, varnames) + return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters) diff --git a/test/threadsafe.jl b/test/threadsafe.jl deleted file mode 100644 index 0421c89e2..000000000 --- a/test/threadsafe.jl +++ /dev/null @@ -1,116 +0,0 @@ -@testset "threadsafe.jl" begin - @testset "constructor" begin - vi = VarInfo(gdemo_default) - threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) - - @test threadsafe_vi.varinfo === vi - @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} - @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2 - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... 
- ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - end - - # TODO: Add more tests of the public API - @testset "API" begin - vi = VarInfo(gdemo_default) - threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) - - lp = getlogjoint(vi) - @test getlogjoint(threadsafe_vi) == lp - - threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42) - @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42 - @test getlogjoint(vi) == lp - @test getlogjoint(threadsafe_vi) == lp + 42 - - threadsafe_vi = DynamicPPL.resetaccs!!(threadsafe_vi) - @test iszero(getlogjoint(threadsafe_vi)) - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... - ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - - threadsafe_vi = setlogprior!!(threadsafe_vi, 42) - @test getlogjoint(threadsafe_vi) == 42 - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... - ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - end - - @testset "model" begin - println("Peforming threading tests with $(Threads.nthreads()) threads") - - x = rand(10_000) - - @model function wthreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) - Threads.@threads for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) - end - end - model = wthreads(x) - - vi = VarInfo() - model(vi) - lp_w_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("With `@threads`:") - println(" default:") - @time model(vi) - - # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - # check that it's wrapped during the model evaluation - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - # ensure that it's unwrapped after evaluation finishes - @test vi isa VarInfo - - println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) - - @model function wothreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) - for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) - end - end - model = wothreads(x) - - vi = VarInfo() - model(vi) - lp_wo_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("Without `@threads`:") - println(" default:") - @time model(vi) - - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - @test vi_ isa VarInfo - @test vi isa VarInfo - - println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) - end -end diff --git a/test/varinfo.jl b/test/varinfo.jl index ba7c17b34..745697315 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,5 +1,5 @@ function check_varinfo_keys(varinfo, vns) - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleVarInfo{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, # since `keys(varinfo_merged)` only contains `VarName` with `identity`. 
# So we just check that the original keys are present. @@ -42,7 +42,7 @@ end end model = gdemo(1.0, 2.0) - vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform()) tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata @@ -325,7 +325,7 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval! & setval_and_resample!" begin + @testset "setval!" begin @model function testmodel(x) n = length(x) s ~ truncated(Normal(); lower=0) @@ -376,8 +376,8 @@ end else DynamicPPL.setval!(vicopy, (m=zeros(5),)) end - # Setting `m` fails for univariate due to limitations of `setval!` - # and `setval_and_resample!`. See docstring of `setval!` for more info. + # Setting `m` fails for univariate due to limitations of `setval!`. + # See docstring of `setval!` for more info. if model == model_uv && vi in [vi_untyped, vi_typed] @test_broken vicopy[m_vns] == zeros(5) else @@ -402,57 +402,6 @@ end DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. - continue - end - - if vi in [vi_vnv, vi_vnv_typed] - # `setval_and_resample!` works differently for `VarNamedVector`: All - # values will be resampled when model(vicopy) is called. Hence the below - # tests are not applicable. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - # Ordering is NOT preserved. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] != vi[s_vns] - - # Correct ordering. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 end end @@ -466,9 +415,6 @@ end ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals - - DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals end @testset "setval! 
on chain" begin @@ -488,7 +434,7 @@ end θ_new = var_info[:] @test θ_old != θ_new vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) for (n, v) in mapreduce(collect, vcat, iters) n = string(n) if Symbol(n) ∉ keys(chain) @@ -533,17 +479,18 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model using SampleFromUniform does not + # Check that instantiating the model using InitFromUniform does not # perform linking - # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) - # specifically in this test is because SFU samples from the linked - # distribution i.e. in unconstrained space. However, it does this not - # by linking the varinfo but by transforming the distributions on the - # fly. That's why it's worth specifically checking that it can do this - # without having to change the VarInfo object. + # Note (penelopeysm): The purpose of using InitFromUniform specifically in + # this test is because it samples from the linked distribution i.e. in + # unconstrained space. However, it does this not by linking the varinfo + # but by transforming the distributions on the fly. That's why it's + # worth specifically checking that it can do this without having to + # change the VarInfo object. + # TODO(penelopeysm): Move this to InitFromUniform tests rather than here. vi = VarInfo() meta = vi.metadata - _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -607,7 +554,7 @@ end function test_linked_varinfo(model, vi) # vn and dist are taken from the containing scope - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + vi = last(DynamicPPL.init!!(model, vi, InitFromPrior())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test istrans(vi, vn) @@ -618,6 +565,11 @@ end @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) end + ### `VarInfo` + # Need to run once since we can't specify that we want to _sample_ + # in the unconstrained space for `VarInfo` without having `vn` + # present in the `varinfo`. + ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) @@ -628,11 +580,6 @@ end vi = DynamicPPL.settrans!!(vi, true, vn) test_linked_varinfo(model, vi) - ## `typed_varinfo` - vi = DynamicPPL.typed_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) - test_linked_varinfo(model, vi) - ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) @@ -653,9 +600,7 @@ end vns = DynamicPPL.TestUtils.varnames(model) # Set up the different instances of `AbstractVarInfo` with the desired values. - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, example_values, vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) @testset "$(short_varinfo_name(vi))" for vi in varinfos # Just making sure. 
DynamicPPL.TestUtils.test_values(vi, example_values, vns) @@ -698,11 +643,9 @@ end @testset "mutating=$mutating" for mutating in [false, true] value_true = DynamicPPL.TestUtils.rand_prior_true(model) varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleVarInfo{<:NamedTuple} # NOTE: this is broken since we'll end up trying to set # # varinfo[@varname(x[4:5])] = [x[4],] @@ -775,14 +718,11 @@ end end model = demo(0.0) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, (; x=1.0), (@varname(x),); include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, (; x=1.0), (@varname(x),)) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # Skip the inconcrete `SimpleVarInfo` types, since checking for type # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{<:AbstractDict} || - varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} + if varinfo isa SimpleVarInfo{<:AbstractDict} continue end @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) @@ -802,13 +742,9 @@ end vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] # `VarInfo` supports, effectively, arbitrary subsetting. - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, model(), vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, model(), vns) varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) + varinfos_simple = filter(Base.Fix2(isa, DynamicPPL.SimpleVarInfo), varinfos) # `VarInfo` supports subsetting using, basically, arbitrary varnames. vns_supported_standard = [ @@ -848,8 +784,7 @@ end # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, ## i.e. `VarName{sym}()` without any indexing, etc. vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple + if varinfo isa DynamicPPL.SimpleVarInfo && values_as(varinfo) isa NamedTuple vns_supported_simple else vns_supported_standard @@ -921,10 +856,7 @@ end @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, - DynamicPPL.TestUtils.rand_prior_true(model), - vns; - include_threadsafe=true, + model, DynamicPPL.TestUtils.rand_prior_true(model), vns ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @testset "with itself" begin @@ -1012,58 +944,15 @@ end @test merge(vi_double, vi_single)[vn] == 1.0 end - @testset "sampling from linked varinfo" begin - # `~` - @model function demo(n=1) - x = Vector(undef, n) - for i in eachindex(x) - x[i] ~ Exponential() - end - return x - end - model1 = demo(1) - varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. 
- model2 = demo(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) - for vn in [@varname(x[1]), @varname(x[2])] - @test DynamicPPL.istrans(varinfo2, vn) - end - - # `.~` - @model function demo_dot(n=1) - x ~ Exponential() - if n > 1 - y = Vector(undef, n - 1) - y .~ Exponential() - end - return x - end - model1 = demo_dot(1) - varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. - model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) - for vn in [@varname(x), @varname(y[1])] - @test DynamicPPL.istrans(varinfo2, vn) - end - end - # NOTE: It is not yet clear if this is something we want from all varinfo types. # Hence, we only test the `VarInfo` types here. @testset "vector_getranges for `VarInfo`" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) nt = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, nt, vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, nt, vns) # Only keep `VarInfo` types. - varinfos = filter( - Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos - ) + varinfos = filter(Base.Fix2(isa, DynamicPPL.VarInfo), varinfos) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos x = values_as(varinfo, Vector) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 57a8175d4..de7a7c186 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -586,9 +586,7 @@ end value_true = DynamicPPL.TestUtils.rand_prior_true(model) vns = DynamicPPL.TestUtils.varnames(model) varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=false - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) # Filter out those which are not based on `VarNamedVector`. varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) # Get the true log joint. @@ -610,9 +608,7 @@ end DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last( - DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) - ) + varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) # Log density should be different. @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different.
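The test hunks above repeatedly swap `DynamicPPL.evaluate_and_sample!!` (with `SampleFromPrior`/`SampleFromUniform`) for `DynamicPPL.init!!` plus an explicit initialisation strategy. A minimal sketch of that pattern, using an illustrative model and parameter values that are not taken from the test suite (only the `init!!`/`InitFrom*` calls mirror the usage in these hunks):

```julia
using DynamicPPL, Distributions

# Illustrative model (not from the test suite).
@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end
model = demo()

# Fresh draws from the prior; `init!!` returns a (retval, varinfo) pair.
_, vi = DynamicPPL.init!!(model, VarInfo(), InitFromPrior())

# Draws from U[-2, 2] in unconstrained space, transformed back on the fly,
# so the varinfo itself is never linked.
_, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform())

# Fixed starting values; the same strategy is passed to `sample` via the
# `initial_params` keyword in the updated sampler tests.
_, vi = DynamicPPL.init!!(model, VarInfo(), InitFromParams((; s=4.0, m=-1.0)))
```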