Skip to content

Commit

Permalink
Additional edits
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed Dec 23, 2024
1 parent f9c0f48 commit d7d368c
Show file tree
Hide file tree
Showing 2 changed files with 474 additions and 304 deletions.
623 changes: 346 additions & 277 deletions examples/time_series/Euler-Maruyama_and_SDEs.ipynb

Large diffs are not rendered by default.

155 changes: 128 additions & 27 deletions examples/time_series/Euler-Maruyama_and_SDEs.myst.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,12 @@ run_control:
slideshow:
slide_type: '-'
---
import warnings
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import scipy as sp
# Ignore UserWarnings
warnings.filterwarnings("ignore", category=UserWarning)
RANDOM_SEED = 8927
np.random.seed(RANDOM_SEED)
```

```{code-cell} ipython3
Expand Down Expand Up @@ -104,19 +96,16 @@ run_control:
slideshow:
slide_type: subslide
---
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
ax1.plot(x_t[:30], "k", label="$x(t)$", alpha=0.5)
ax1.plot(z_t[:30], "r", label="$z(t)$", alpha=0.5)
ax1.set_title("Transient")
ax1.legend()
ax2.plot(x_t[30:], "k", label="$x(t)$", alpha=0.5)
ax2.plot(z_t[30:], "r", label="$z(t)$", alpha=0.5)
ax2.set_title("All time")
ax2.legend()
plt.tight_layout()
plt.figure(figsize=(10, 3))
plt.plot(x_t[:30], "k", label="$x(t)$", alpha=0.5)
plt.plot(z_t[:30], "r", label="$z(t)$", alpha=0.5)
plt.title("Transient")
plt.legend()
plt.subplot(122)
plt.plot(x_t[30:], "k", label="$x(t)$", alpha=0.5)
plt.plot(z_t[30:], "r", label="$z(t)$", alpha=0.5)
plt.title("All time")
plt.legend();
```

+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}
Expand All @@ -134,7 +123,7 @@ new_sheet: false
run_control:
read_only: false
---
def lin_sde(x, lam, s2):
def lin_sde(x, lam):
return lam * x, s2
```

Expand All @@ -155,12 +144,11 @@ slideshow:
---
with pm.Model() as model:
# uniform prior, but we know it must be negative
l = pm.HalfCauchy("l", beta=1)
s = pm.Uniform("s", 0.005, 0.5)
l = pm.Flat("l")
# "hidden states" following a linear SDE distribution
# parametrized by time step (det. variable) and lam (random variable)
xh = pm.EulerMaruyama("xh", dt=dt, sde_fn=lin_sde, sde_pars=(-l, s**2), shape=N, initval=x_t)
xh = pm.EulerMaruyama("xh", dt=dt, sde_fn=lin_sde, sde_pars=(l,), shape=N)
# predicted observation
zh = pm.Normal("zh", mu=xh, sigma=5e-3, observed=z_t)
Expand All @@ -178,7 +166,7 @@ run_control:
read_only: false
---
with model:
trace = pm.sample(nuts_sampler="nutpie", random_seed=RANDOM_SEED, target_accept=0.99)
trace = pm.sample()
```

+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
Expand All @@ -197,7 +185,7 @@ plt.plot(x_t, "r", label="$x(t)$")
plt.legend()
plt.subplot(122)
plt.hist(-1 * az.extract(trace.posterior)["l"], 30, label=r"$\hat{\lambda}$", alpha=0.5)
plt.hist(az.extract(trace.posterior)["l"], 30, label=r"$\hat{\lambda}$", alpha=0.5)
plt.axvline(lam, color="r", label=r"$\lambda$", alpha=0.5)
plt.legend();
```
Expand Down Expand Up @@ -230,6 +218,119 @@ plt.legend();

+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}

Note that

- inference also estimates the initial conditions
- the observed data $z(t)$ lies fully within the 95% interval of the PPC.
- there are many other ways of evaluating fit

+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "slide"}}

### Toy model 2

As the next model, let's use a 2D deterministic oscillator,
\begin{align}
\dot{x} &= \tau (x - x^3/3 + y) \\
\dot{y} &= \frac{1}{\tau} (a - x)
\end{align}

with noisy observation $z(t) = m x + (1 - m) y + N(0, 0.05)$.

```{code-cell} ipython3
N, tau, a, m, s2 = 200, 3.0, 1.05, 0.2, 1e-1
xs, ys = [0.0], [1.0]
for i in range(N):
x, y = xs[-1], ys[-1]
dx = tau * (x - x**3.0 / 3.0 + y)
dy = (1.0 / tau) * (a - x)
xs.append(x + dt * dx + np.sqrt(dt) * s2 * np.random.randn())
ys.append(y + dt * dy + np.sqrt(dt) * s2 * np.random.randn())
xs, ys = np.array(xs), np.array(ys)
zs = m * xs + (1 - m) * ys + np.random.randn(xs.size) * 0.1
plt.figure(figsize=(10, 2))
plt.plot(xs, label="$x(t)$")
plt.plot(ys, label="$y(t)$")
plt.plot(zs, label="$z(t)$")
plt.legend()
```

+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}

Now, estimate the hidden states $x(t)$ and $y(t)$, as well as parameters $\tau$, $a$ and $m$.

As before, we rewrite our SDE as a function returned drift & diffusion coefficients:

```{code-cell} ipython3
---
button: false
new_sheet: false
run_control:
read_only: false
---
def osc_sde(xy, tau, a):
x, y = xy[:, 0], xy[:, 1]
dx = tau * (x - x**3.0 / 3.0 + y)
dy = (1.0 / tau) * (a - x)
dxy = pt.stack([dx, dy], axis=0).T
return dxy, s2
```

+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}

As before, the Euler-Maruyama discretization of the SDE is written as a prediction of the state at step $i+1$ based on the state at step $i$.

+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}

We can now write our statistical model as before, with uninformative priors on $\tau$, $a$ and $m$:

```{code-cell} ipython3
---
button: false
new_sheet: false
run_control:
read_only: false
---
xys = np.c_[xs, ys]
with pm.Model() as model:
tau_h = pm.Uniform("tau_h", lower=0.1, upper=5.0)
a_h = pm.Uniform("a_h", lower=0.5, upper=1.5)
m_h = pm.Uniform("m_h", lower=0.0, upper=1.0)
xy_h = pm.EulerMaruyama(
"xy_h", dt=dt, sde_fn=osc_sde, sde_pars=(tau_h, a_h), shape=xys.shape, initval=xys
)
zh = pm.Normal("zh", mu=m_h * xy_h[:, 0] + (1 - m_h) * xy_h[:, 1], sigma=0.1, observed=zs)
```

```{code-cell} ipython3
pm.__version__
```

```{code-cell} ipython3
---
button: false
new_sheet: false
run_control:
read_only: false
---
with model:
pm.sample_posterior_predictive(trace, extend_inferencedata=True)
```

```{code-cell} ipython3
plt.figure(figsize=(10, 3))
plt.plot(
trace.posterior_predictive.quantile((0.025, 0.975), dim=("chain", "draw"))["zh"].values.T,
"k",
label=r"$z_{95\% PP}(t)$",
)
plt.plot(z_t, "r", label="$z(t)$")
plt.legend();
```

+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}

Note that the initial conditions are also estimated, and that most of the observed data $z(t)$ lies within the 95% interval of the PPC.

Another approach is to look at draws from the sampling distribution of the data relative to the observed data. This too shows a good fit across the range of observations -- the posterior predictive mean almost perfectly tracks the data.
Expand Down

0 comments on commit d7d368c

Please sign in to comment.