Skip to content

Commit

Permalink
Improve testing, add a new test, fix some things (#36)
Browse files Browse the repository at this point in the history
First estimate the transition density `p_T(0,v)` for fix `T` of the non-linear diffusion `P` by forward-simulating paths and then taking an empirical, marginal distribution at time `T`. Then, compare to an alternative estimation method, where for *each value of `v`* (where `v`s are simulated at random) compute the stochastic importance weight by simulating proposal bridge and then assigning weight that corrects for the discrepancy between the proposal and the target. The contribution from the path is averaged out and only the contribution from the transition density from  the starting point to end-point `v` is left and it also comes from the correct law, because of the importance correction step.
  • Loading branch information
mschauer authored Sep 17, 2019
1 parent a5859fb commit 8803e2c
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 50 deletions.
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
name = "BridgeSDEInference"
uuid = "46d747a0-b9e1-11e9-14b5-615c73e45078"
authors = ["Marcin Mider <[email protected]>", "mschauer <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
Bridge = "2d3116d5-4b8f-5680-861c-71f149790274"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "Statistics", "Random", "LinearAlgebra"]

[compat]
julia = "1"
6 changes: 3 additions & 3 deletions src/guid_prop_bridge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,11 @@ function llikelihood(::LeftRule, X::SamplePath, P::GuidPropBridge; skip = 0)
* (tt[i+1]-tt[i]) )

if !constdiff(P)
H = H((i,s), x, P)
Hi = H((i,s), x, P)
som -= ( 0.5*tr( (a((i,s), x, target(P))
- aitilde((i,s), x, P))*H ) * (tt[i+1]-tt[i]) )
- a((i,s), x, auxiliary(P)))*Hi ) * (tt[i+1]-tt[i]) )
som += ( 0.5*( r'*(a((i,s), x, target(P))
- aitilde((i,s), x, P))*r ) * (tt[i+1]-tt[i]) )
- a((i,s), x, auxiliary(P)))*r ) * (tt[i+1]-tt[i]) )
end
end
som
Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
using Test

using Bridge, StaticArrays, Distributions
using Bridge, BridgeSDEInference, StaticArrays, Distributions
using Statistics, Random, LinearAlgebra
POSSIBLE_PARAMS = [:regular, :simpleAlter, :complexAlter, :simpleConjug,
:complexConjug]
SRC_DIR = joinpath(Base.source_dir(), "..", "src")

const BSI = BridgeSDEInference
using BridgeSDEInference:

include("test_ODE_solver_change_pt.jl")
include("test_blocking.jl")
include("test_measchange.jl")
21 changes: 11 additions & 10 deletions test/test_ODE_solver_change_pt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@


POSSIBLE_PARAMS = [:regular, :simpleAlter, :complexAlter, :simpleConjug,
:complexConjug]

function change_point_test_prep(N=10000, λ=N/10)
L = @SMatrix [1. 0.;
0. 1.]
Expand All @@ -8,34 +13,30 @@ function change_point_test_prep(N=10000, λ=N/10)

param = :complexConjug
# Target law
= FitzhughDiffusion(param, θ₀...)
= BSI.FitzhughDiffusion(param, θ₀...)

# Auxiliary law
t₀ = 1.0
T = 2.0
x0 = {2}(-0.5, 2.25)
xT = {2}(1.0, 0.0)
= FitzhughDiffusionAux(param, θ₀..., t₀, L*x0, T, L*xT)
= BSI.FitzhughDiffusionAux(param, θ₀..., t₀, L*x0, T, L*xT)

τ(t₀,T) = (x) -> t₀ + (x-t₀) * (2-(x-t₀)/(T-t₀))
dt = (T-t₀)/N
tt = τ(t₀,T).(t₀:dt:T)

P₁ = GuidPropBridge(eltype(x0), tt, P˟, P̃, L, L*x0, Σ;
changePt=NoChangePt(), solver=Vern7())
P₁ = BSI.GuidPropBridge(eltype(x0), tt, P˟, P̃, L, L*x0, Σ;
changePt=BSI.NoChangePt(), solver=BSI.Vern7())

P₂ = GuidPropBridge(eltype(x0), tt, P˟, P̃, L, L*x0, Σ;
changePt=SimpleChangePt(λ), solver=Vern7())
P₂ = BSI.GuidPropBridge(eltype(x0), tt, P˟, P̃, L, L*x0, Σ;
changePt=BSI.SimpleChangePt(λ), solver=BSI.Vern7())
P₁, P₂
end

@testset "change point between ODE solvers" begin

parametrisation = POSSIBLE_PARAMS[5]
include(joinpath(SRC_DIR, "fitzHughNagumo.jl"))
include(joinpath(SRC_DIR, "types.jl"))
include(joinpath(SRC_DIR, "vern7.jl"))
include(joinpath(SRC_DIR, "guid_prop_bridge.jl"))
N = 10000
P₁, P₂ = change_point_test_prep(N)
@testset "comparing H[$i]" for i in 1:div(N,20):N
Expand Down
59 changes: 28 additions & 31 deletions test/test_blocking.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
parametrisation = POSSIBLE_PARAMS[5]
include(joinpath(SRC_DIR, "types.jl"))
include(joinpath(SRC_DIR, "fitzHughNagumo.jl"))
include(joinpath(SRC_DIR, "guid_prop_bridge.jl"))
include(joinpath(SRC_DIR, "blocking_schedule.jl"))
include(joinpath(SRC_DIR, "vern7.jl"))

POSSIBLE_PARAMS = [:regular, :simpleAlter, :complexAlter, :simpleConjug,
:complexConjug]
SRC_DIR = joinpath(Base.source_dir(), "..", "src")

parametrisation = POSSIBLE_PARAMS[5]

function blocking_test_prep(obs=.([1.0, 1.2, 0.8, 1.3, 2.0]),
tt=[0.0, 1.0, 1.5, 2.3, 4.0],
knots=collect(1:length(obs)-2)[1:1:end],
changePtBuffer=100)
θ₀ = [10.0, -8.0, 25.0, 0.0, 3.0]
= FitzhughDiffusion(θ₀...)
= [FitzhughDiffusionAux(θ₀..., t₀, u[1], T, v[1]) for (t₀,T,u,v)
= BSI.FitzhughDiffusion(θ₀...)
= [BSI.FitzhughDiffusionAux(θ₀..., t₀, u[1], T, v[1]) for (t₀,T,u,v)
in zip(tt[1:end-1], tt[2:end], obs[1:end-1], obs[2:end])]
L = @SMatrix [1. 0.]
Σdiagel = 10^(-10)
Expand All @@ -23,32 +20,32 @@ function blocking_test_prep(obs=ℝ.([1.0, 1.2, 0.8, 1.3, 2.0]),
Σs =for _ in P̃]
τ(t₀,T) = (x) -> t₀ + (x-t₀) * (2-(x-t₀)/(T-t₀))
m = length(obs) - 1
P = Array{ContinuousTimeProcess,1}(undef,m)
P = Array{BSI.ContinuousTimeProcess,1}(undef,m)
dt = 1/50
for i in m:-1:1
numPts = Int64(ceil((tt[i+1]-tt[i])/dt))+1
t = τ(tt[i], tt[i+1]).( range(tt[i], stop=tt[i+1], length=numPts) )
P[i] = ( (i==m) ? GuidPropBridge(Float64, t, P˟, P̃[i], Ls[i], obs[i+1], Σs[i];
changePt=NoChangePt(changePtBuffer),
solver=Vern7()) :
GuidPropBridge(Float64, t, P˟, P̃[i], Ls[i], obs[i+1], Σs[i],
P[i] = ( (i==m) ? BSI.GuidPropBridge(Float64, t, P˟, P̃[i], Ls[i], obs[i+1], Σs[i];
changePt=BSI.NoChangePt(changePtBuffer),
solver=BSI.Vern7()) :
BSI.GuidPropBridge(Float64, t, P˟, P̃[i], Ls[i], obs[i+1], Σs[i],
P[i+1].H[1], P[i+1].Hν[1], P[i+1].c[1];
changePt=NoChangePt(changePtBuffer),
solver=Vern7()) )
changePt=BSI.NoChangePt(changePtBuffer),
solver=BSI.Vern7()) )
end

T = SArray{Tuple{2},Float64,1,2}
TW = typeof(sample([0], Wiener{Float64}()))
TX = typeof(SamplePath([], zeros(T, 0)))
TW = typeof(sample([0], BSI.Wiener{Float64}()))
TX = typeof(BSI.SamplePath([], zeros(T, 0)))
XX = Vector{TX}(undef,m)
WW = Vector{TW}(undef,m)
for i in 1:m
XX[i] = SamplePath(P[i].tt, zeros(T, length(P[i].tt)))
XX[i] = BSI.SamplePath(P[i].tt, zeros(T, length(P[i].tt)))
XX[i].yy .= [T(obs[i+1][1], i) for _ in 1:length(XX[i].yy)]
end

blockingParams = (knots, 10^(-7), SimpleChangePt(changePtBuffer))
𝔅 = ChequeredBlocking(blockingParams..., P, WW, XX)
blockingParams = (knots, 10^(-7), BSI.SimpleChangePt(changePtBuffer))
𝔅 = BSI.ChequeredBlocking(blockingParams..., P, WW, XX)
for i in 1:m
𝔅.XXᵒ[i].yy .= [T(obs[i+1][1], 10+i) for _ in 1:length(XX[i].yy)]
end
Expand All @@ -70,8 +67,8 @@ end
@test 𝔅.knots[2] == [2]
@test 𝔅.blocks[1] == [[1], [2, 3], [4]]
@test 𝔅.blocks[2] == [[1, 2], [3, 4]]
@test 𝔅.changePts[1] == [SimpleChangePt(100), NoChangePt(100), SimpleChangePt(100), NoChangePt(100)]
@test 𝔅.changePts[2] == [NoChangePt(100), SimpleChangePt(100), NoChangePt(100), NoChangePt(100)]
@test 𝔅.changePts[1] == [BSI.SimpleChangePt(100), BSI.NoChangePt(100), BSI.SimpleChangePt(100), BSI.NoChangePt(100)]
@test 𝔅.changePts[2] == [BSI.NoChangePt(100), BSI.SimpleChangePt(100), BSI.NoChangePt(100), BSI.NoChangePt(100)]
@test 𝔅.vs == obs[2:end]
@test 𝔅.Ls[1] == [I, L, I, L]
@test 𝔅.Ls[2] == [L, I, L, L]
Expand All @@ -80,18 +77,18 @@ end
end

θ = [10.0, -8.0, 15.0, 0.0, 3.0]
𝔅 = next(𝔅, 𝔅.XX, θ)
𝔅 = BSI.next(𝔅, 𝔅.XX, θ)

@testset "validity of blocking state after calling next" begin
@test 𝔅.idx == 2
@testset "checking if θ has been propagated everywhere" for i in 1:length(tt)-1
@test params(𝔅.P[i].Target) == θ
@test params(𝔅.P[i].Pt) == θ
@test BSI.params(𝔅.P[i].Target) == θ
@test BSI.params(𝔅.P[i].Pt) == θ
end
@test [𝔅.P[i].Σ for i in 1:length(tt)-1 ] == 𝔅.Σs[2] == [Σ, I*ϵ, Σ, Σ]
@test [𝔅.P[i].L for i in 1:length(tt)-1 ] == 𝔅.Ls[2] == [L, I, L, L]
@test [𝔅.P[i].v for i in 1:length(tt)-1 ] == [obs[2], 𝔅.XX[2].yy[end], obs[4], obs[5]]
@test [𝔅.P[i].changePt for i in 1:length(tt)-1 ] == 𝔅.changePts[2] == [NoChangePt(100), SimpleChangePt(100), NoChangePt(100), NoChangePt(100)]
@test [𝔅.P[i].changePt for i in 1:length(tt)-1 ] == 𝔅.changePts[2] == [BSI.NoChangePt(100), BSI.SimpleChangePt(100), BSI.NoChangePt(100), BSI.NoChangePt(100)]
end

θᵒ = [1.0, -7.0, 10.0, 2.0, 1.0]
Expand All @@ -110,17 +107,17 @@ end
end
end

𝔅 = next(𝔅, 𝔅.XX, θᵒ)
𝔅 = BSI.next(𝔅, 𝔅.XX, θᵒ)

@testset "validity of blocking state after second call to next" begin
@test 𝔅.idx == 1
@testset "checking if θᵒ has been propagated everywhere" for i in 1:length(tt)-1
@test params(𝔅.P[i].Target) == θᵒ
@test params(𝔅.P[i].Pt) == θᵒ
@test BSI.params(𝔅.P[i].Target) == θᵒ
@test BSI.params(𝔅.P[i].Pt) == θᵒ
end
@test [𝔅.P[i].Σ for i in 1:length(tt)-1 ] == 𝔅.Σs[1] == [I*ϵ, Σ, I*ϵ, Σ]
@test [𝔅.P[i].L for i in 1:length(tt)-1 ] == 𝔅.Ls[1] == [I, L, I, L]
@test [𝔅.P[i].v for i in 1:length(tt)-1 ] == [𝔅.XX[1].yy[end], obs[3], 𝔅.XX[3].yy[end], obs[5]]
@test [𝔅.P[i].changePt for i in 1:length(tt)-1 ] == 𝔅.changePts[1] == [SimpleChangePt(100), NoChangePt(100), SimpleChangePt(100), NoChangePt(100)]
@test [𝔅.P[i].changePt for i in 1:length(tt)-1 ] == 𝔅.changePts[1] == [BSI.SimpleChangePt(100), BSI.NoChangePt(100), BSI.SimpleChangePt(100), BSI.NoChangePt(100)]
end
end
115 changes: 115 additions & 0 deletions test/test_measchange.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Test that the transition density
# in a non-linear, non-homogenous, non-constant diffusivity model
# estimated by forward simulation
# agrees with the density obtained from the linearisation
# reweighted with importance weights using guided proposals.
const 𝕏 = SVector{1}
using GaussianDistributions

struct TargetSDE <: Bridge.ContinuousTimeProcess{Float64}
end
struct LinearSDE{T} <: Bridge.ContinuousTimeProcess{Float64}
σ::T
end
Bridge.b(s, x, P::TargetSDE) = -0.1x + .5sin(x[1]) + 0.5sin(s/4)
Bridge.b(s, x, P::LinearSDE) = Bridge.B(s, P)*x + Bridge.β(s, P)
Bridge.B(s, P::LinearSDE) = SMatrix{1}(-0.1)
Bridge.β(s, P::LinearSDE) = 𝕏(0.5sin(s/4))

Bridge.σ(s, x, P::TargetSDE) = SMatrix{1}(2.0 + 0.5cos(x[1]))
Bridge.σ(s, x, P::LinearSDE) = P.σ
Bridge.σ(s, P::LinearSDE) = P.σ
Bridge.a(s, P::LinearSDE) = P.σ^2

Bridge.constdiff(::TargetSDE) = false
Bridge.constdiff(::LinearSDE) = true

binind(r, x) = searchsortedfirst(r, x) - 1


function test_measchange()
Random.seed!(1)
fextra = 1.
f = 1.0
T = round(f*4*pi, digits=2)
P = TargetSDE()
v = 𝕏(pi/2)

x0 = 𝕏(-pi/2)

t = 0:0.01*f*fextra:T
t = Bridge.tofs.(t, 0, T)
W = Bridge.samplepath(t, 𝕏(0.0))

Wnr = Wiener{𝕏{Float64}}()

Σ = SMatrix{1}(0.1)
L = SMatrix{1}(1.0)
Noise = Gaussian(𝕏(0.0), Σ)

sample!(W, Wnr)
X = solve(Euler(), x0, W, P)
v1 = X.yy[end]
X.yy[end] = zero(v1)
solve!(Euler(), X, x0, W, P)
@test v1 X.yy[end]



K = 50
vrange = range(-10,10, length=K+1)
vints = [(vrange[i], vrange[i+1]) for i in 1:K]

k = 1
N = 50000

# Forward simulation

vs = Float64[]
for i in 1:N
sample!(W, Wnr)
solve!(Euler(), X, x0, W, P)
v = L*X.yy[end] + rand(Noise)
push!(vs, v[1])
end

counts = zeros(K+2)
[counts[binind(vrange, v)+1] += 1 for v in vs]
counts /= length(vs)

wcounts = zeros(K)


VProp = Uniform(-10,10)

fpt = fill(NaN, 1)
v = 𝕏(0.0)
= LinearSDE(Bridge.σ(T, v, P)) # use a law with large variance
Pᵒ = BridgeSDEInference.GuidPropBridge(eltype(x0), t, P, P̃, L, v, Σ)

# Guided proposals

for i in 1:N
v = 𝕏(5*rand(VProp))
while !((binind(vrange, v[1]) in 1:K))
v = 𝕏(5*rand(VProp))
end
# other possibility: change proposal each step
# P̃ = LinearSDE(Bridge.σ(T, v, P))
Pᵒ = BridgeSDEInference.GuidPropBridge(eltype(x0), t, P, P̃, L, v, Σ)

sample!(W, Wnr)
solve!(Euler(), X, x0, W, Pᵒ)
ll = BSI.pathLogLikhd(BridgeSDEInference.PartObs(), [X], [Pᵒ], 1:1, fpt, skipFPT=true)
ll += BSI.lobslikelihood(Pᵒ, x0)
ll -= logpdf(VProp, v[1])
wcounts[binind(vrange, v[1])] += exp(ll)/N
end
bias = wcounts - counts[2:end-1]

@testset "Statistical correctness of guided proposals" begin
@test norm(bias) < 0.05
end
end

test_measchange()

0 comments on commit 8803e2c

Please sign in to comment.