-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve testing, add a new test, fix some things (#36)
First estimate the transition density `p_T(0,v)` for fix `T` of the non-linear diffusion `P` by forward-simulating paths and then taking an empirical, marginal distribution at time `T`. Then, compare to an alternative estimation method, where for *each value of `v`* (where `v`s are simulated at random) compute the stochastic importance weight by simulating proposal bridge and then assigning weight that corrects for the discrepancy between the proposal and the target. The contribution from the path is averaged out and only the contribution from the transition density from the starting point to end-point `v` is left and it also comes from the correct law, because of the importance correction step.
- Loading branch information
Showing
6 changed files
with
167 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,25 @@ | ||
name = "BridgeSDEInference" | ||
uuid = "46d747a0-b9e1-11e9-14b5-615c73e45078" | ||
authors = ["Marcin Mider <[email protected]>", "mschauer <[email protected]>"] | ||
version = "0.1.0" | ||
version = "0.1.1" | ||
|
||
[deps] | ||
Bridge = "2d3116d5-4b8f-5680-861c-71f149790274" | ||
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d" | ||
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" | ||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" | ||
|
||
[extras] | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["Test"] | ||
test = ["Test", "Statistics", "Random", "LinearAlgebra"] | ||
|
||
[compat] | ||
julia = "1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
using Test | ||
|
||
using Bridge, StaticArrays, Distributions | ||
using Bridge, BridgeSDEInference, StaticArrays, Distributions | ||
using Statistics, Random, LinearAlgebra | ||
POSSIBLE_PARAMS = [:regular, :simpleAlter, :complexAlter, :simpleConjug, | ||
:complexConjug] | ||
SRC_DIR = joinpath(Base.source_dir(), "..", "src") | ||
|
||
const BSI = BridgeSDEInference | ||
using BridgeSDEInference: ℝ | ||
|
||
include("test_ODE_solver_change_pt.jl") | ||
include("test_blocking.jl") | ||
include("test_measchange.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# Test that the transition density | ||
# in a non-linear, non-homogenous, non-constant diffusivity model | ||
# estimated by forward simulation | ||
# agrees with the density obtained from the linearisation | ||
# reweighted with importance weights using guided proposals. | ||
const 𝕏 = SVector{1} | ||
using GaussianDistributions | ||
|
||
struct TargetSDE <: Bridge.ContinuousTimeProcess{Float64} | ||
end | ||
struct LinearSDE{T} <: Bridge.ContinuousTimeProcess{Float64} | ||
σ::T | ||
end | ||
Bridge.b(s, x, P::TargetSDE) = -0.1x + .5sin(x[1]) + 0.5sin(s/4) | ||
Bridge.b(s, x, P::LinearSDE) = Bridge.B(s, P)*x + Bridge.β(s, P) | ||
Bridge.B(s, P::LinearSDE) = SMatrix{1}(-0.1) | ||
Bridge.β(s, P::LinearSDE) = 𝕏(0.5sin(s/4)) | ||
|
||
Bridge.σ(s, x, P::TargetSDE) = SMatrix{1}(2.0 + 0.5cos(x[1])) | ||
Bridge.σ(s, x, P::LinearSDE) = P.σ | ||
Bridge.σ(s, P::LinearSDE) = P.σ | ||
Bridge.a(s, P::LinearSDE) = P.σ^2 | ||
|
||
Bridge.constdiff(::TargetSDE) = false | ||
Bridge.constdiff(::LinearSDE) = true | ||
|
||
binind(r, x) = searchsortedfirst(r, x) - 1 | ||
|
||
|
||
function test_measchange() | ||
Random.seed!(1) | ||
fextra = 1. | ||
f = 1.0 | ||
T = round(f*4*pi, digits=2) | ||
P = TargetSDE() | ||
v = 𝕏(pi/2) | ||
|
||
x0 = 𝕏(-pi/2) | ||
|
||
t = 0:0.01*f*fextra:T | ||
t = Bridge.tofs.(t, 0, T) | ||
W = Bridge.samplepath(t, 𝕏(0.0)) | ||
|
||
Wnr = Wiener{𝕏{Float64}}() | ||
|
||
Σ = SMatrix{1}(0.1) | ||
L = SMatrix{1}(1.0) | ||
Noise = Gaussian(𝕏(0.0), Σ) | ||
|
||
sample!(W, Wnr) | ||
X = solve(Euler(), x0, W, P) | ||
v1 = X.yy[end] | ||
X.yy[end] = zero(v1) | ||
solve!(Euler(), X, x0, W, P) | ||
@test v1 ≈ X.yy[end] | ||
|
||
|
||
|
||
K = 50 | ||
vrange = range(-10,10, length=K+1) | ||
vints = [(vrange[i], vrange[i+1]) for i in 1:K] | ||
|
||
k = 1 | ||
N = 50000 | ||
|
||
# Forward simulation | ||
|
||
vs = Float64[] | ||
for i in 1:N | ||
sample!(W, Wnr) | ||
solve!(Euler(), X, x0, W, P) | ||
v = L*X.yy[end] + rand(Noise) | ||
push!(vs, v[1]) | ||
end | ||
|
||
counts = zeros(K+2) | ||
[counts[binind(vrange, v)+1] += 1 for v in vs] | ||
counts /= length(vs) | ||
|
||
wcounts = zeros(K) | ||
|
||
|
||
VProp = Uniform(-10,10) | ||
|
||
fpt = fill(NaN, 1) | ||
v = 𝕏(0.0) | ||
P̃ = LinearSDE(Bridge.σ(T, v, P)) # use a law with large variance | ||
Pᵒ = BridgeSDEInference.GuidPropBridge(eltype(x0), t, P, P̃, L, v, Σ) | ||
|
||
# Guided proposals | ||
|
||
for i in 1:N | ||
v = 𝕏(5*rand(VProp)) | ||
while !((binind(vrange, v[1]) in 1:K)) | ||
v = 𝕏(5*rand(VProp)) | ||
end | ||
# other possibility: change proposal each step | ||
# P̃ = LinearSDE(Bridge.σ(T, v, P)) | ||
Pᵒ = BridgeSDEInference.GuidPropBridge(eltype(x0), t, P, P̃, L, v, Σ) | ||
|
||
sample!(W, Wnr) | ||
solve!(Euler(), X, x0, W, Pᵒ) | ||
ll = BSI.pathLogLikhd(BridgeSDEInference.PartObs(), [X], [Pᵒ], 1:1, fpt, skipFPT=true) | ||
ll += BSI.lobslikelihood(Pᵒ, x0) | ||
ll -= logpdf(VProp, v[1]) | ||
wcounts[binind(vrange, v[1])] += exp(ll)/N | ||
end | ||
bias = wcounts - counts[2:end-1] | ||
|
||
@testset "Statistical correctness of guided proposals" begin | ||
@test norm(bias) < 0.05 | ||
end | ||
end | ||
|
||
test_measchange() |