@@ -202,18 +202,37 @@ struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo
202202 transformation:: C
203203end
204204
205- SimpleVarInfo (values, logp) = SimpleVarInfo (values, logp, NoTransformation ())
205+ # Makes things a bit more readable vs. putting `Float64` everywhere.
206+ const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64
206207
208+ function SimpleVarInfo {NT,T} (values, logp) where {NT,T}
209+ return SimpleVarInfo {NT,T,NoTransformation} (values, logp, NoTransformation ())
210+ end
207211function SimpleVarInfo {T} (θ) where {T<: Real }
208- return SimpleVarInfo (θ, zero (T))
212+ return SimpleVarInfo {typeof(θ),T} (θ, zero (T))
213+ end
214+
215+ # Constructors without type-specification.
216+ SimpleVarInfo (θ) = SimpleVarInfo {SIMPLEVARINFO_DEFAULT_ELTYPE} (θ)
217+ function SimpleVarInfo (θ:: Union{<:NamedTuple,<:AbstractDict} )
218+ return if isempty (θ)
219+ # Can't infer from values, so we just use default.
220+ SimpleVarInfo {SIMPLEVARINFO_DEFAULT_ELTYPE} (θ)
221+ else
222+ # Infer from `values`.
223+ SimpleVarInfo {float_type_with_fallback(infer_nested_eltype(typeof(θ)))} (θ)
224+ end
209225end
226+
227+ SimpleVarInfo (values, logp) = SimpleVarInfo {typeof(values),typeof(logp)} (values, logp)
228+
229+ # Using `kwargs` to specify the values.
210230function SimpleVarInfo {T} (; kwargs... ) where {T<: Real }
211231 return SimpleVarInfo {T} (NamedTuple (kwargs))
212232end
213233function SimpleVarInfo (; kwargs... )
214- return SimpleVarInfo {Float64} (NamedTuple (kwargs))
234+ return SimpleVarInfo (NamedTuple (kwargs))
215235end
216- SimpleVarInfo (θ) = SimpleVarInfo {Float64} (θ)
217236
218237# Constructor from `Model`.
219238SimpleVarInfo (model:: Model , args... ) = SimpleVarInfo {Float64} (model, args... )
@@ -582,3 +601,12 @@ julia> # Truth.
582601```
583602"""
584603Distributions. loglikelihood (model:: Model , θ) = loglikelihood (model, SimpleVarInfo (θ))
604+
605+ # Threadsafe stuff.
606+ # For `SimpleVarInfo` we don't really need `Ref` so let's not use it.
607+ function ThreadSafeVarInfo (vi:: SimpleVarInfo )
608+ return ThreadSafeVarInfo (vi, zeros (typeof (getlogp (vi)), Threads. nthreads ()))
609+ end
610+ function ThreadSafeVarInfo (vi:: SimpleVarInfo{<:Any,<:Ref} )
611+ return ThreadSafeVarInfo (vi, [Ref (zero (getlogp (vi))) for _ in 1 : Threads. nthreads ()])
612+ end
0 commit comments