Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
KernelSpectralDensities = "027d52a2-76e5-4228-9bfe-bc7e0f5a8348"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this into a package extension?

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -24,6 +25,7 @@ Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
IrrationalConstants = "0.1, 0.2"
KernelFunctions = "0.9, 0.10"
KernelSpectralDensities = "0.2.0"
LinearAlgebra = "1"
PDMats = "0.11"
Random = "1"
Expand Down
5 changes: 5 additions & 0 deletions src/AbstractGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using RecipesBase
using IrrationalConstants: log2π

using KernelFunctions: ColVecs, RowVecs
using KernelSpectralDensities

using ChainRulesCore: ChainRulesCore

Expand All @@ -33,6 +34,7 @@ export rand!,
posterior,
update_posterior
export ColVecs, RowVecs
export GPSampler, CholeskySampling, Conditional, Independent, RFFSampling, PathwiseSampling

# Various bits of utility functionality.
include("util/common_covmat_ops.jl")
Expand All @@ -56,6 +58,9 @@ include("sparse_approximations.jl")
# LatentGP and LatentFiniteGP objects to accommodate GPs with non-Gaussian likelihoods.
include("latent_gp.jl")

# Different sampling methods
include("sampling.jl")

# Plotting utilities.
include("util/plotting.jl")

Expand Down
264 changes: 264 additions & 0 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
abstract type AbstractGPSamplingMethod end

SeedableRNG = Union{Xoshiro,MersenneTwister}

_rand(rng, d) = Random.rand(rng, d)
function _rand(rng::AbstractRNG, ::Type{T}) where {T<:SeedableRNG}
return T(Random.rand(rng, 1:typemax(Int)))
end

# ## Interface

struct GPSample{F,S}
fun::F
sample::S
end

(gs::GPSample)(x::AbstractArray) = eval_at(gs.fun, gs.sample, x)

# This may become more challenging once we extend to multi-input GPS
(gs::GPSample)(x::Number) = only(gs([x]))

"""
GPSampler(gp::AbstractGPs.AbstractGP, method::AbstractGPSamplingMethod)
Creates a sampler for the given `gp` using the specified `method`.

```jldoctest
julia> f = GP(Matern32Kernel());

julia> gps = GPSampler(f, CholeskySampling());

julia> rand(gps);
```
"""
struct GPSampler{F,S} <: Random.Sampler{GPSample}
fun::F
sampler::S

# Specify input types here, since it is a "public" interface
function GPSampler(gp::AbstractGPs.AbstractGP, method::AbstractGPSamplingMethod)
fun, sampler = method(gp)
return new{typeof(fun),typeof(sampler)}(fun, sampler)
end
end

# Don't love the deepcopy here
# issue is "pass by sharing" and the mutable struct in CholeskySampling
function Random.rand(rng::AbstractRNG, gs::GPSampler)
return GPSample(deepcopy(gs.fun), _rand(rng, gs.sampler))
end

# ## Utils

_get_prior(gp::AbstractGPs.GP) = gp
_get_prior(pgp::AbstractGPs.PosteriorGP) = pgp.prior

function get_obs_variance(pgp::AbstractGPs.PosteriorGP)
σk = pgp.prior.kernel(0, 0)
v = diag(pgp.data.C.L * pgp.data.C.U) .- σk
return v
end

function get_target_prior(pgp::AbstractGPs.PosteriorGP)
m = pgp.data.δ
σ2 = get_obs_variance(pgp)
return MvNormal(m, sqrt.(σ2))
end

#########################
# Function Space/ Cholesky

"""
CholeskySampling(s=Conditional, generator=Xoshiro)
Sampling by using the standard way, via Cholesky decomposition.
Arguments:
- `s`: Sampling type, either `Conditional` or `Independent`. Default is `Conditional`.
- `generator`: Random number generator to use in each sample. Default is `Xoshiro`.
"""
struct CholeskySampling{M,G} <: AbstractGPSamplingMethod
function CholeskySampling(s=Conditional, generator=Xoshiro)
return new{s,generator}()
end
end

function (cs::CholeskySampling{M,G})(gp) where {M,G}
return M(gp), G
end

"""
Conditional
Generates a GP sample that conditions function samples on all previous samples.
"""
mutable struct Conditional{GPT<:AbstractGPs.AbstractGP}
gp::GPT
end

function Conditional(gp::AbstractGPs.GP)
data = (
α=Vector{Float64}(undef, 0),
C=Cholesky(UpperTriangular(Matrix{Float64}(undef, 0, 0))),
x=Vector{Float64}(undef, 0),
δ=Vector{Float64}(undef, 0),
)
pgp = AbstractGPs.PosteriorGP(gp, data)
return Conditional(pgp)
end

function eval_at(s::Conditional, rng, x::AbstractArray)
if isassigned(s.gp.data.x, 1)
pgp = s.gp
else
pgp = s.gp.prior
end
fgp = pgp(x)
y = rand(rng, fgp)
s.gp = posterior(fgp, y)
return y
end

"""
Independent
Generates a GP sample that samples function samples independent from previous samples.
"""
struct Independent{GPT<:AbstractGPs.AbstractGP}
gp::GPT
function Independent(gp)
return new{typeof(gp)}(gp)
end
end

function eval_at(s::Independent, rng, x::AbstractArray)
gp = s.gp
fgp = gp(x)
y = rand(rng, fgp)
return y
end

# ## WeightSpace

# ### Utils

get_weight_distribution(::AbstractGPs.GP, rff) = MvNormal(ones(rff.l))

function get_weight_distribution(pgp::AbstractGPs.PosteriorGP, rff)
d = get_target_prior(pgp)

P = rff.(pgp.data.x)
Pt = reduce(hcat, P)
Cp = Symmetric(Pt * (d.Σ \ Pt') + I)
C = cholesky(Cp)

μ = C \ (Pt * (d.Σ \ d.μ))
Σ = C \ I
return MvNormal(μ, Symmetric(Σ))
end

# ### Main

"""
RFFSampling(l::Int, rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF)
Sampling by using Random Fourier Features.
Arguments:
- `l`: Number of random Fourier features to use.
- `rff_type`: Type of random Fourier features to use. Default is `DoubleRFF`.
"""
struct RFFSampling{RFF,RNG} <: AbstractGPSamplingMethod
l::Int
rng::RNG
function RFFSampling(
rng, l; rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF
)
return new{rff_type,typeof(rng)}(l, rng)
end
end

function RFFSampling(l; rff_type::Type{<:KernelSpectralDensities.AbstractRFF}=DoubleRFF)
return RFFSampling(Random.default_rng(), l; rff_type)
end

function (rffs::RFFSampling{RFF})(gp) where {RFF}
prior = _get_prior(gp)
# for now, hardcoding "1", later expand for multi-input
S = SpectralDensity(prior.kernel, 1)
# ToDo: add rng to RFF
rff = RFF(rffs.rng, S, rffs.l)

ws = get_weight_distribution(gp, rff)

return rff, ws
end

function eval_at(rff::KernelSpectralDensities.AbstractRFF, w, x::AbstractArray)
return dot.(rff.(x), Ref(w))
end

# ## PathwiseSampler

# ### utils
struct KernelBasis{K}
ker::K
x::AbstractArray
end

(kb::KernelBasis)(x) = kb.ker.(Ref(x), kb.x)

function update_basis(pgp, cs::CholeskySampling)
ker = pgp.prior.kernel
x = pgp.data.x
return KernelBasis(ker, x)
end

function update_basis(pgp, rffs::RFFSampling)
rff, _ = rffs(pgp)

Check warning on line 212 in src/sampling.jl

View check run for this annotation

Codecov / codecov/patch

src/sampling.jl#L211-L212

Added lines #L211 - L212 were not covered by tests

return rff

Check warning on line 214 in src/sampling.jl

View check run for this annotation

Codecov / codecov/patch

src/sampling.jl#L214

Added line #L214 was not covered by tests
end

# ### Main

"""
PathwiseSampling(l::Int)
Sampling by using pathwise sampling, which uses RFF sampling for the prior and an update rule
based on the kernel. Takes as an input the number of random Fourier features `l` to use.
"""
struct PathwiseSampling{P,U} <: AbstractGPSamplingMethod
prior::P
update::U
end

function PathwiseSampling(l::Int)
return PathwiseSampling(RFFSampling(l), CholeskySampling())
end

struct PathwiseSampler{PS,TS,D}
prior_sampler::PS
target_sampler::TS
data::D
end

function (ps::PathwiseSampling)(pgp::AbstractGPs.PosteriorGP)
upd_fun = update_basis(pgp, ps.update)

prior = pgp.prior
prior_sampler = GPSampler(prior, ps.prior)

target_dist = get_target_prior(pgp)

data = (C=pgp.data.C, x=pgp.data.x)
return upd_fun, PathwiseSampler(prior_sampler, target_dist, data)
end

function _rand(rng::AbstractRNG, ps::PathwiseSampler)
prior = rand(rng, ps.prior_sampler)
f = prior(ps.data.x)

ts = rand(rng, ps.target_sampler)

v = ps.data.C \ (ts - f)

return (prior=prior, v=v)
end

function eval_at(basis::KernelBasis, s, x::AbstractArray)
return s.prior(x) .+ dot.(basis.(x), Ref(s.v))
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ include("test_util.jl")
println(" ")
@info "Ran latent_gp tests"

include("sampling.jl")
println(" ")
@info "Ran sampling tests"

include("deprecations.jl")
println(" ")
@info "Ran deprecation tests"
Expand Down
Loading
Loading