From 82bb795255acdcabe557e8c7ec39a9bc5b0d7e93 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 15 Oct 2019 14:36:12 +0100 Subject: [PATCH 1/3] Very simple separable implementation --- Project.toml | 1 + src/extras/separable.jl | 71 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 src/extras/separable.jl diff --git a/Project.toml b/Project.toml index f0ff00b2..0efebae1 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" +Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/extras/separable.jl b/src/extras/separable.jl new file mode 100644 index 00000000..5cc404db --- /dev/null +++ b/src/extras/separable.jl @@ -0,0 +1,71 @@ +""" + RectilinearGrid{T, Txl, Txr} <: AbstractVector{T} + +A vector of length `length(xl) * length(xr)` which represents a matrix of data of size +`(length(xl), length(xr))`, the `ij`th element of which is the point `(xl[i], xr[j])`. +""" +struct RectilinearGrid{T, Txl, Txr} <: AbstractVector{T} + xl::Txl + xr::Txr + function RectilinearGrid(xl::AV{T}, xr::AV{V}) where {T, V} + return new{promote_type(T, V), typeof(xl), typeof(xr)}(xl, xr) + end +end +Base.size(x::RectilinearGrid) = (length(x),) +Base.length(x::RectilinearGrid) = length(x.xl) * length(x.xr) + + +""" + Separable{Tkl, Tkr} <: Kernel + +A kernel that is separable over two input dimensions +""" +struct Separable{Tkl, Tkr} <: Kernel + kl::Tkl + kr::Tkr +end + +pw(k::Separable, x::RectilinearGrid) = pw(k.kl, x.xl) ⊗ pw(k.kr, x.xr) +ew(k::Separable, x::RectilinearGrid) = error("Not implemented") + +pw(k::Separable, x::RectilinearGrid, x′::RectilinearGrid) = error("Not implemented") +ew(k::Separable, x::RectilinearGrid, x′::RectilinearGrid) = error("Not implemented") + +const SeparableGP = GP{<:MeanFunction, <:Separable} +const SeparableFiniteGP = FiniteGP{<:SeparableGP, <:RectilinearGrid, <:Diagonal} + +function logpdf(f::SeparableFiniteGP, y::AV{<:Real}) + + # Check that data and grid are the same lengths. + @assert length(f.x) == length(y) + + # Check that the observation noise is isotropic. Ideally move this to compile-time + # at some point,, although likely not a bottleneck. + σ²_n = first(f.Σy.diag) + @assert all(f.Σy.diag .== σ²_n) + + # Compute log marginal likelihood. A little bit uglier than I would like. Ideally we + # would just write `eigen(cov(f))` and the correct thing would happen, but the + # things aren't currently implemented correctly for this to be the case. + K_eig = eigen(cov(f.f, f.x)) + σ²_n * I + + λ, Γ = K_eig + β = Diagonal(1 ./ sqrt.(λ)) * (Γ'y) + return -(length(y) * log(2π) + logdet(K_eig) + sum(abs2, β)) / 2 +end + +function rand(rng::AbstractRNG, f::SeparableFiniteGP, N::Int) + + # Check that the observation noise is isotropic. Ideally move this to compile-time + # at some point,, although likely not a bottleneck. + σ²_n = first(f.Σy.diag) + @assert all(f.Σy.diag .== σ²_n) + + # Compute log marginal likelihood. A little bit uglier than I would like. Ideally we + # would just write `eigen(cov(f))` and the correct thing would happen, but the + # things aren't currently implemented correctly for this to be the case. + K_eig = eigen(cov(f.f, f.x)) + σ²_n * I + + λ, Γ = K_eig + return Γ * (Diagonal(sqrt.(λ)) * randn(rng, length(λ), N)) +end From fc0cc8d8923635fc6e3200cd32770df7b88b3c8e Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 15 Oct 2019 15:00:38 +0100 Subject: [PATCH 2/3] Include the code... --- src/Stheno.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Stheno.jl b/src/Stheno.jl index b76b5655..b5123d23 100644 --- a/src/Stheno.jl +++ b/src/Stheno.jl @@ -59,4 +59,8 @@ module Stheno # Various stuff for convenience. include(joinpath("util", "model.jl")) include(joinpath("util", "plotting.jl")) + + # Helpful functionality that sits on top of Stheno but isn't fully integrated + # into the CompositeGP infrastructure, and requires more code than simply a new kernel. + include(joinpath("extras", "separable.jl")) end # module From 7c84f78b456f848df29b53e1721282e30f6eb344 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 15 Oct 2019 15:03:51 +0100 Subject: [PATCH 3/3] using Kronecker --- src/extras/separable.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/extras/separable.jl b/src/extras/separable.jl index 5cc404db..59f4970c 100644 --- a/src/extras/separable.jl +++ b/src/extras/separable.jl @@ -1,3 +1,5 @@ +using Kronecker + """ RectilinearGrid{T, Txl, Txr} <: AbstractVector{T}