diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 035d8ff49..e78bf602f 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -59,14 +59,14 @@ chosen_combinations = [ false, ), ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), + # ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 225e40cd8..e6988d3f2 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -94,9 +94,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend - ) + f = DynamicPPL.FastLDF(model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..29a4e2cc7 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL.get_range_and_linked), args... +) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..e49b81cb2 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,9 +1,10 @@ module DynamicPPLMooncakeExt -using DynamicPPL: DynamicPPL, is_transformed +using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked using Mooncake: Mooncake # This is purely an optimisation. 
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg} end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..77d527ced 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -191,6 +191,7 @@ include("simple_varinfo.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") +include("fastldf.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. 
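+
+For example (an illustrative sketch, assuming ForwardDiff is loaded; the exact `Dual`
+type parameters are immaterial here):
+
+    get_matching_type(Float64, Array{Float64,2})
+    # == Array{Float64,2}: plain evaluation leaves types unchanged
+    get_matching_type(ForwardDiff.Dual{Nothing,Float64,2}, Array{Float64,2})
+    # == Array{ForwardDiff.Dual{Nothing,Float64,2},2}: AD widens float slots to duals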
""" get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/fastldf.jl b/src/fastldf.jl new file mode 100644 index 000000000..eaca0c795 --- /dev/null +++ b/src/fastldf.jl @@ -0,0 +1,444 @@ +""" +fasteval.jl +----------- + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we implement here `OnlyAccsVarInfo`, which is a VarInfo that only +contains accumulators. When evaluating a model with `OnlyAccsVarInfo`, it is mandatory that +the model's leaf context is a `FastEvalContext`, which provides extremely fast access to +parameter values. No writing of values into VarInfo metadata is performed at all. + +Vector parameters +----------------- + +We first consider the case of parameter vectors, i.e., the case which would normally be +handled by `unflatten` and `evaluate!!`. Unfortunately, it is not enough to just store +the vector of parameters in the `FastEvalContext`, because it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +However, we want to avoid doing this. Thus, here, we _extract this information from the +VarInfo_ a single time when constructing a `FastLDF` object. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable +numbers of parameters, or models which may visit random variables in different orders depending +on stochastic control flow. 
+NamedTuple and Dict parameters
+------------------------------
+
+Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Such
+representations are capable of handling models with variable sizes and stochastic control
+flow.
+
+However, the path towards implementing these is straightforward:
+
+1. Currently, `FastEvalVectorContext` lets us take a VarName and look up the corresponding
+   range in the parameter vector, plus a boolean indicating whether the value is linked or
+   unlinked. See the `get_range_and_linked` function for details.
+
+2. We would need to implement similar contexts for NamedTuple and Dict parameters. The
+   functionality would be quite similar to `InitContext(InitFromParams(...))`.
+"""
+
+"""
+    OnlyAccsVarInfo
+
+This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo`
+interface to work with the `accumulate_assume!!` and `accumulate_observe!!` functions.
+
+Note that it implements almost none of the other AbstractVarInfo interface functions, and so
+using it outside of FastLDF will lead to errors.
+
+Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field.
+That is because values for random variables are obtained by reading from a separate entity
+(such as a `FastEvalVectorContext`), rather than from the VarInfo itself.
+"""
+struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo
+    accs::Accs
+end
+OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
+DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
+DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
+DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
+function DynamicPPL.get_param_eltype(
+    ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model
+)
+    # Because the VarInfo has no parameters stored in it, we need to get the eltype from the
+    # model's leaf context. This is only possible if said leaf context is indeed a FastEval
+    # context.
+    leaf_ctx = DynamicPPL.leafcontext(model.context)
+    if leaf_ctx isa FastEvalVectorContext
+        return eltype(leaf_ctx.params)
+    else
+        error(
+            "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))",
+        )
+    end
+end
+
+"""
+    RangeAndLinked
+
+Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable
+in the model will in general correspond to a sub-vector of `params`. This struct stores
+information about that range, as well as whether the sub-vector represents a linked value or
+an unlinked value.
+
+$(TYPEDFIELDS)
+"""
+struct RangeAndLinked
+    # indices that the variable corresponds to in the vectorised parameter
+    range::UnitRange{Int}
+    # whether it's linked
+    is_linked::Bool
+end
+
+"""
+    AbstractFastEvalContext
+
+Abstract type representing fast evaluation contexts. This is currently only subtyped by
+`FastEvalVectorContext`. However, in the future, similar contexts may be implemented for
+NamedTuple and Dict parameters.
+""" +abstract type AbstractFastEvalContext <: AbstractContext end +DynamicPPL.NodeTrait(::AbstractFastEvalContext) = IsLeaf() + +""" + FastEvalVectorContext( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + params::AbstractVector{<:Real}, + ) + +A context that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to unify the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: + AbstractFastEvalContext + # This NamedTuple stores the ranges for identity VarNames + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames + varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values + params::T +end +function get_range_and_linked( + ctx::FastEvalVectorContext, ::VarName{sym,typeof(identity)} +) where {sym} + return ctx.iden_varname_ranges[sym] +end +function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName) + return ctx.varname_ranges[vn] +end + +function tilde_assume!!( + ctx::FastEvalVectorContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + # Note that this function does not use the metadata field of `vi` at all. + range_and_linked = get_range_and_linked(ctx, vn) + y = @view ctx.params[range_and_linked.range] + f = if range_and_linked.is_linked + from_linked_vec_transform(right) + else + from_vec_transform(right) + end + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi +end + +function tilde_observe!!( + ::FastEvalVectorContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) + # This is the same as for DefaultContext + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi +end + +######################################## +# Log-density functions using FastEval # +######################################## + +""" + FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. 
+
+If the `adtype` keyword argument is provided, then this struct will also store the adtype
+along with other information for efficient calculation of the gradient of the log density.
+Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend
+itself to have been loaded (e.g. with `import Backend`).
+
+## Fields
+
+Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from:
+
+- `fastldf.model`: The original model from which this `FastLDF` was constructed.
+- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
+  type was provided.
+
+## Extended help
+
+`FastLDF` uses `FastEvalVectorContext` internally to provide extremely rapid evaluation of
+the model given a vector of parameters.
+
+Because it is common to call `LogDensityProblems.logdensity` and
+`LogDensityProblems.logdensity_and_gradient` within tight loops, it is beneficial for us to
+pre-compute as much of the information as possible when constructing the `FastLDF` object.
+In particular, we use the provided VarInfo's metadata to extract the mapping from VarNames
+to ranges and link status, and store this mapping inside the `FastLDF` object. We can later
+use this to construct a `FastEvalVectorContext`, without having to inspect the metadata
+again.
+"""
+struct FastLDF{
+    M<:Model,
+    AD<:Union{ADTypes.AbstractADType,Nothing},
+    F<:Function,
+    N<:NamedTuple,
+    ADP<:Union{Nothing,DI.GradientPrep},
+}
+    model::M
+    adtype::AD
+    _getlogdensity::F
+    # See FastEvalVectorContext for an explanation of these two fields.
+    _iden_varname_ranges::N
+    _varname_ranges::Dict{VarName,RangeAndLinked}
+    _adprep::ADP
+
+    function FastLDF(
+        model::Model,
+        getlogdensity::Function=getlogjoint_internal,
+        varinfo::AbstractVarInfo=VarInfo(model);
+        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+    )
+        # Figure out which variable corresponds to which index, and
+        # which variables are linked.
+        all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
+        # Do AD prep if needed
+        prep = if adtype === nothing
+            nothing
+        else
+            # Make backend-specific tweaks to the adtype
+            adtype = tweak_adtype(adtype, model, varinfo)
+            x = [val for val in varinfo[:]]
+            DI.prepare_gradient(
+                FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
+                adtype,
+                x,
+            )
+        end
+        return new{
+            typeof(model),
+            typeof(adtype),
+            typeof(getlogdensity),
+            typeof(all_iden_ranges),
+            typeof(prep),
+        }(
+            model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep
+        )
+    end
+end
+
+###################################
+# LogDensityProblems.jl interface #
+###################################
+"""
+    fast_ldf_accs(getlogdensity::Function)
+
+Determine which accumulators are needed for fast evaluation with the given
+`getlogdensity` function.
+"""
+fast_ldf_accs(::Function) = default_accumulators()
+fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators()
+function fast_ldf_accs(::typeof(getlogjoint))
+    return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
+end
+function fast_ldf_accs(::typeof(getlogprior_internal))
+    return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator()))
+end
+fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
+fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
+
+struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
+    _model::M
+    _getlogdensity::F
+    _iden_varname_ranges::N
+    _varname_ranges::Dict{VarName,RangeAndLinked}
+end
+function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
+    ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params)
+    model = DynamicPPL.setleafcontext(f._model, ctx)
+    accs = fast_ldf_accs(f._getlogdensity)
+    # Calling `evaluate!!` would be fine, but would lead to an extra call to `resetaccs!!`,
+    # which is unnecessary. So we short-circuit this by simply calling `_evaluate!!`
+    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
+    # here.
+    vi = if Threads.nthreads() > 1
+        accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(params)), acc), accs)
+        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
+    else
+        OnlyAccsVarInfo(accs)
+    end
+    _, vi = _evaluate!!(model, vi)
+    return f._getlogdensity(vi)
+end
+
+function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
+    return FastLogDensityAt(
+        fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
+    )(
+        params
+    )
+end
+
+function LogDensityProblems.logdensity_and_gradient(
+    fldf::FastLDF, params::AbstractVector{<:Real}
+)
+    return DI.value_and_gradient(
+        FastLogDensityAt(
+            fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
+        ),
+        fldf._adprep,
+        fldf.adtype,
+        params,
+    )
+end
+
+######################################################
+# Helper functions to extract ranges and link status #
+######################################################
+
+# TODO: This fails for SimpleVarInfo; support could be added if there turns out to be
+# demand for it.
+
+"""
+    get_ranges_and_linked(varinfo::VarInfo)
+
+Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter
+representation, along with whether each variable is linked or unlinked.
+
+This function returns a tuple containing:
+
+- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked`;
+- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`.
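+
+For example (purely illustrative; the exact grouping depends on the optics of the
+VarNames), a VarInfo containing a univariate `x` (unlinked) and a linked `y[1]` might
+yield
+
+    ((x = RangeAndLinked(1:1, false),),
+     Dict(@varname(y[1]) => RangeAndLinked(2:2, true)))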
+"""
+function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms}
+    all_iden_ranges = NamedTuple()
+    all_ranges = Dict{VarName,RangeAndLinked}()
+    offset = 1
+    for sym in syms
+        md = varinfo.metadata[sym]
+        this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset)
+        all_iden_ranges = merge(all_iden_ranges, this_md_iden)
+        all_ranges = merge(all_ranges, this_md_others)
+    end
+    return all_iden_ranges, all_ranges
+end
+function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}})
+    all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1)
+    return all_iden, all_others
+end
+function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
+    all_iden_ranges = NamedTuple()
+    all_ranges = Dict{VarName,RangeAndLinked}()
+    offset = start_offset
+    for (vn, idx) in md.idcs
+        is_linked = md.is_transformed[idx]
+        range = md.ranges[idx] .+ (start_offset - 1)
+        if AbstractPPL.getoptic(vn) === identity
+            all_iden_ranges = merge(
+                all_iden_ranges,
+                NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
+            )
+        else
+            all_ranges[vn] = RangeAndLinked(range, is_linked)
+        end
+        offset += length(range)
+    end
+    return all_iden_ranges, all_ranges, offset
+end
+function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
+    all_iden_ranges = NamedTuple()
+    all_ranges = Dict{VarName,RangeAndLinked}()
+    offset = start_offset
+    for (vn, idx) in vnv.varname_to_index
+        is_linked = vnv.is_unconstrained[idx]
+        range = vnv.ranges[idx] .+ (start_offset - 1)
+        if AbstractPPL.getoptic(vn) === identity
+            all_iden_ranges = merge(
+                all_iden_ranges,
+                NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
+            )
+        else
+            all_ranges[vn] = RangeAndLinked(range, is_linked)
+        end
+        offset += length(range)
+    end
+    return all_iden_ranges, all_ranges, offset
+end
diff --git a/src/model.jl b/src/model.jl
index edb042ba9..6ca06aea6 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -986,9 +986,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
 ) where {_F,argnames}
     unwrap_args = [
         if is_splat_symbol(var)
-            :($matchingvalue(varinfo, model.args.$var)...)
+            :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)...)
         else
-            :($matchingvalue(varinfo, model.args.$var))
+            :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var))
         end for var in argnames
     ]
     return quote
@@ -1006,6 +1006,22 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
     end
 end
 
+"""
+    get_param_eltype(varinfo::AbstractVarInfo, model::Model)
+
+Get the element type of the parameters being used to evaluate the `model` from the
+`varinfo`. For example, when performing AD with ForwardDiff, this should return
+`ForwardDiff.Dual`.
+
+By default, this uses `eltype(varinfo)`, which is somewhat of a hack: it relies on the
+fact that, typically, the parameters will have been inserted into the VarInfo's metadata
+before evaluation.
+
+See `OnlyAccsVarInfo` for an example of where this is not true (the parameters are instead
+stored in the model's context).
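+
+For instance (an illustrative sketch; the exact eltype depends on how the VarInfo was
+constructed):
+
+    vi = VarInfo(model)
+    get_param_eltype(vi, model)  # typically Float64 for a freshly-constructed VarInfo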
+""" +get_param_eltype(varinfo::AbstractVarInfo, ::Model) = eltype(varinfo) + """ getargnames(model::Model) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..fbbae85b7 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,8 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link +using DynamicPPL: Model, FastLDF, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -265,7 +264,7 @@ function run_ad( # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) + ldf = FastLDF(model, getlogdensity, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 @@ -282,9 +281,7 @@ function run_ad( value_true = test.value grad_true = test.grad elseif test isa WithBackend - ldf_reference = LogDensityFunction( - model, getlogdensity, varinfo; adtype=test.adtype - ) + ldf_reference = FastLDF(model, getlogdensity, varinfo; adtype=test.adtype) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad_true = collect(grad_true) diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..0d77c70e6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -394,7 +394,9 @@ end for f in names mdf = :(metadata.$f) len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)]))) + push!( + exprs, :($f = unflatten_metadata($mdf, @view x[($offset + 1):($offset + $len)])) + ) offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) @@ -751,11 +753,10 @@ function getdist(::VarNamedVector, ::VarName) end getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) -# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though, -# since then we might be returning a `SubArray` rather than an `Array`, which is typically -# what a bijector would result in, even if the input is a view (`SubArray`). -# TODO(torfjelde): An alternative is to implement `view` directly instead. -getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) +function getindex_internal(md::Metadata, vn::VarName) + rng = getrange(md, vn) + return @view md.vals[rng] +end function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) end @@ -1495,8 +1496,21 @@ space. If some but only some of the variables in `vi` are transformed, this function will return `true`. This behavior will likely change in the future. 
""" -function is_transformed(vi::VarInfo) - return any(is_transformed(vi, vn) for vn in keys(vi)) +function is_transformed(vi::NTVarInfo) + return is_transformed(vi.metadata) +end + +@generated function is_transformed(nt::NamedTuple{names}) where {names} + expr = Expr(:block) + push!(expr.args, :(result = false)) + for n in names + push!(expr.args, :(result = result || is_transformed(nt.$n))) + end + return expr +end + +function is_transformed(md::Metadata) + return any(md.is_transformed) end # The default getindex & setindex!() for get & set values @@ -1552,7 +1566,7 @@ end @generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} expr = Expr(:tuple) for f in names - push!(expr.args, :(metadata.$f.vals[ranges.$f])) + push!(expr.args, :(@view metadata.$f.vals[ranges.$f])) end return expr end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 17b851d1d..a81f33ea5 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -367,6 +367,12 @@ Return a boolean for whether `vn` is guaranteed to have been transformed so that is all of Euclidean space. """ is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] +""" + is_transformed(vnv::VarNamedVector) + +Return true if any variable in `vnv` is guaranteed to have been transformed. +""" +is_transformed(vnv::VarNamedVector) = any(vnv.is_unconstrained) """ set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) diff --git a/test/ad.jl b/test/ad.jl index d7505aab2..48b1b64ec 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,5 +1,6 @@ -using DynamicPPL: LogDensityFunction +using DynamicPPL: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using Random: Xoshiro @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -15,64 +16,25 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] end - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11 = v"1.11" <= VERSION < v"1.12" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end + varinfo = VarInfo(Xoshiro(468), m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = linked_varinfo[:] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any end end end @@ -83,7 +45,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest test_m = randn(2, 3) function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) + ldf = FastLDF(model(); adtype=adtype) return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -6,8 +6,10 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES