Skip to content

Commit 535ce4f

Browse files
penelopeysmmhauru
andauthored
FastLDF / InitContext unified (#1132)
* Fast Log Density Function * Make it work with AD * Optimise performance for identity VarNames * Mark `get_range_and_linked` as having zero derivative * Update comment * make AD testing / benchmarking use FastLDF * Fix tests * Optimise away `make_evaluate_args_and_kwargs` * const func annotation * Disable benchmarks on non-typed-Metadata-VarInfo * Fix `_evaluate!!` correctly to handle submodels * Actually fix submodel evaluate * Document thoroughly and organise code * Support more VarInfos, make it thread-safe (?) * fix bug in parsing ranges from metadata/VNV * Fix get_param_eltype for TSVI * Disable Enzyme benchmark * Don't override _evaluate!!, that breaks ForwardDiff (sometimes) * Move FastLDF to experimental for now * Fix imports, add tests, etc * More test fixes * Fix imports / tests * Remove AbstractFastEvalContext * Changelog and patch bump * Add correctness tests, fix imports * Concretise parameter vector in tests * Add zero-allocation tests * Add Chairmarks as test dep * Disable allocations tests on multi-threaded * Fast InitContext (#1125) * Make InitContext work with OnlyAccsVarInfo * Do not convert NamedTuple to Dict * remove logging * Enable InitFromPrior and InitFromUniform too * Fix `infer_nested_eltype` invocation * Refactor FastLDF to use InitContext * note init breaking change * fix logjac sign * workaround Mooncake segfault * fix changelog too * Fix get_param_eltype for context stacks * Add a test for threaded observe * Export init * Remove dead code * fix transforms for pathological distributions * Tidy up loads of things * fix typed_identity spelling * fix definition order * Improve docstrings * Remove stray comment * export get_param_eltype (unfortunatley) * Add more comment * Update comment * Remove inlines, fix OAVI docstring * Improve docstrings * Simplify InitFromParams constructor * Replace map(identity, x[:]) with [i for i in x[:]] * Simplify implementation for InitContext/OAVI * Add another model to allocation tests Co-authored-by: Markus Hauru <[email protected]> * Revert removal of dist argument (oops) * Format * Update some outdated bits of FastLDF docstring * remove underscores --------- Co-authored-by: Markus Hauru <[email protected]>
1 parent 4ca9528 commit 535ce4f

File tree

16 files changed

+955
-74
lines changed

16 files changed

+955
-74
lines changed

HISTORY.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,21 @@ You should not need to use these directly, please use `AbstractPPL.condition` an
2121

2222
Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
2323

24+
The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
25+
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
26+
27+
### Other changes
28+
29+
#### FastLDF
30+
31+
Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
32+
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
33+
34+
Please note that `FastLDF` is currently considered internal and its API may change without warning.
35+
We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it.
36+
37+
For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
38+
2439
## 0.38.9
2540

2641
Remove warning when using Enzyme as the AD backend.

docs/src/api.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ DynamicPPL.prefix
170170

171171
## Utilities
172172

173+
`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors.
174+
175+
```@docs
176+
typed_identity
177+
```
178+
173179
It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.
174180

175181
```@docs
@@ -517,10 +523,12 @@ InitFromParams
517523
```
518524

519525
If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.
526+
In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy.
520527

521528
```@docs
522-
DynamicPPL.AbstractInitStrategy
523-
DynamicPPL.init
529+
AbstractInitStrategy
530+
init
531+
get_param_eltype
524532
```
525533

526534
### Choosing a suitable VarInfo

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
module DynamicPPLEnzymeCoreExt
22

3-
if isdefined(Base, :get_extension)
4-
using DynamicPPL: DynamicPPL
5-
using EnzymeCore
6-
else
7-
using ..DynamicPPL: DynamicPPL
8-
using ..EnzymeCore
9-
end
3+
using DynamicPPL: DynamicPPL
4+
using EnzymeCore
105

116
# Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme
127
# only checks whether such a method exists, and never runs it.
138
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) =
149
nothing
10+
# Likewise for get_range_and_linked.
11+
@inline EnzymeCore.EnzymeRules.inactive(
12+
::typeof(DynamicPPL._get_range_and_linked), args...
13+
) = nothing
1514

1615
end

ext/DynamicPPLMooncakeExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,8 @@ using Mooncake: Mooncake
55

66
# This is purely an optimisation.
77
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg}
8+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{
9+
typeof(DynamicPPL._get_range_and_linked),Vararg
10+
}
811

912
end # module

src/DynamicPPL.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ export AbstractVarInfo,
8484
# Compiler
8585
@model,
8686
# Utilities
87-
init,
8887
OrderedDict,
88+
typed_identity,
8989
# Model
9090
Model,
9191
getmissings,
@@ -113,6 +113,8 @@ export AbstractVarInfo,
113113
InitFromPrior,
114114
InitFromUniform,
115115
InitFromParams,
116+
init,
117+
get_param_eltype,
116118
# Pseudo distributions
117119
NamedDist,
118120
NoDist,
@@ -193,6 +195,7 @@ include("abstract_varinfo.jl")
193195
include("threadsafe.jl")
194196
include("varinfo.jl")
195197
include("simple_varinfo.jl")
198+
include("onlyaccs.jl")
196199
include("compiler.jl")
197200
include("pointwise_logdensities.jl")
198201
include("logdensityfunction.jl")

src/compiler.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -718,14 +718,15 @@ end
718718
# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
719719
# TODO(mhauru) This function needs a more comprehensive docstring.
720720
"""
721-
matchingvalue(vi, value)
721+
matchingvalue(param_eltype, value)
722722
723-
Convert the `value` to the correct type for the `vi` object.
723+
Convert the `value` to the correct type, given the element type of the parameters
724+
being used to evaluate the model.
724725
"""
725-
function matchingvalue(vi, value)
726+
function matchingvalue(param_eltype, value)
726727
T = typeof(value)
727728
if hasmissing(T)
728-
_value = convert(get_matching_type(vi, T), value)
729+
_value = convert(get_matching_type(param_eltype, T), value)
729730
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
730731
# are happy to return `value` as-is?
731732
if _value === value
@@ -738,29 +739,30 @@ function matchingvalue(vi, value)
738739
end
739740
end
740741

741-
function matchingvalue(vi, value::FloatOrArrayType)
742-
return get_matching_type(vi, value)
742+
function matchingvalue(param_eltype, value::FloatOrArrayType)
743+
return get_matching_type(param_eltype, value)
743744
end
744-
function matchingvalue(vi, ::TypeWrap{T}) where {T}
745-
return TypeWrap{get_matching_type(vi, T)}()
745+
function matchingvalue(param_eltype, ::TypeWrap{T}) where {T}
746+
return TypeWrap{get_matching_type(param_eltype, T)}()
746747
end
747748

748749
# TODO(mhauru) This function needs a more comprehensive docstring. What is it for?
749750
"""
750-
get_matching_type(vi, ::TypeWrap{T}) where {T}
751+
get_matching_type(param_eltype, ::TypeWrap{T}) where {T}
751752
752-
Get the specialized version of type `T` for `vi`.
753+
Get the specialized version of type `T`, given an element type of the parameters
754+
being used to evaluate the model.
753755
"""
754756
get_matching_type(_, ::Type{T}) where {T} = T
755-
function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}})
756-
return Union{Missing,float_type_with_fallback(eltype(vi))}
757+
function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}})
758+
return Union{Missing,float_type_with_fallback(param_eltype)}
757759
end
758-
function get_matching_type(vi, ::Type{<:AbstractFloat})
759-
return float_type_with_fallback(eltype(vi))
760+
function get_matching_type(param_eltype, ::Type{<:AbstractFloat})
761+
return float_type_with_fallback(param_eltype)
760762
end
761-
function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N}
762-
return Array{get_matching_type(vi, T),N}
763+
function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N}
764+
return Array{get_matching_type(param_eltype, T),N}
763765
end
764-
function get_matching_type(vi, ::Type{<:Array{T}}) where {T}
765-
return Array{get_matching_type(vi, T)}
766+
function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T}
767+
return Array{get_matching_type(param_eltype, T)}
766768
end

0 commit comments

Comments
 (0)