Skip to content

Commit

Permalink
Merge pull request #93 from ziatdinovmax/bnn2
Browse files Browse the repository at this point in the history
Fixes predictions with structured probabilistic model
  • Loading branch information
ziatdinovmax authored Mar 17, 2024
2 parents 14d1b50 + 784a6f3 commit 64bbec2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
2 changes: 1 addition & 1 deletion gpax/models/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarra

def sample_biases(name: str, channels: int) -> jnp.ndarray:
"""Sampling bias vector"""
b = numpyro.sample(name=name, fn=dist.Normal(
b = numpyro.sample(name=name, fn=dist.Cauchy(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b

Expand Down
34 changes: 29 additions & 5 deletions gpax/models/spm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import jax
import jaxlib
import jax.numpy as jnp
import jax.random as jra
from jax import vmap
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median
Expand Down Expand Up @@ -144,19 +146,44 @@ def sample_from_prior(self, rng_key: jnp.ndarray,
prior_predictive = Predictive(self.model, num_samples=num_samples)
samples = prior_predictive(rng_key, X)
return samples['y']

def sample_single_posterior_predictive(self, rng_key, X_new, params, n_draws):
sigma = params["noise"]
loc = self._model(X_new, params)
sample = dist.Normal(loc, sigma).sample(rng_key, (n_draws,)).mean(0)
return loc, sample

def _vmap_predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
samples: Optional[Dict[str, jnp.ndarray]] = None,
n_draws: int = 1,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Helper method to vectorize predictions over posterior samples
"""
if samples is None:
samples = self.get_samples(chain_dim=False)
num_samples = len(next(iter(samples.values())))
vmap_args = (jra.split(rng_key, num_samples), samples)

predictive = lambda p1, p2: self.sample_single_posterior_predictive(p1, X_new, p2, n_draws)
loc, f_samples = vmap(predictive)(*vmap_args)

return loc, f_samples

def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
samples: Optional[Dict[str, jnp.ndarray]] = None,
n: int = 1,
filter_nans: bool = False, take_point_predictions_mean: bool = True,
device: Type[jaxlib.xla_extension.Device] = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Make prediction at X_new points using sampled GP hyperparameters
Make prediction at X_new points using posterior model parameters
Args:
rng_key: random number generator key
X_new: 2D vector with new/'test' data of :math:`n x num_features` dimensionality
samples: optional posterior samples
n: number of samples to draw from normal distribution per single HMC sample
filter_nans: filter out samples containing NaN values (if any)
take_point_predictions_mean: take a mean of point predictions (without sampling from the normal distribution)
device:
Expand All @@ -172,10 +199,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
if device:
X_new = jax.device_put(X_new, device)
samples = jax.device_put(samples, device)
predictive = Predictive(
self.model, posterior_samples=samples, parallel=True)
y_pred = predictive(rng_key, X_new)
y_pred, y_sampled = y_pred["mu"], y_pred["y"]
y_pred, y_sampled = self._vmap_predict(rng_key, X_new, samples, n)
if filter_nans:
y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
y_sampled = jnp.array(y_sampled_)
Expand Down
20 changes: 16 additions & 4 deletions tests/test_spm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,24 @@ def test_get_samples():
def test_prediction():
rng_keys = get_keys()
X, y = get_dummy_data()
X_test, _ = get_dummy_data()
samples = {"a": jax.random.normal(rng_keys[0], shape=(100, 1)),
"b": jax.random.normal(rng_keys[0], shape=(100,))}
X_test = onp.linspace(X.min(), X.max(), 200)
samples = {"a": jax.random.normal(rng_keys[0], shape=(100,)),
"b": jax.random.normal(rng_keys[0], shape=(100,)),
"noise": jax.random.normal(rng_keys[0], shape=(100,))}
m =sPM(model, model_priors)
y_mean, y_sampled = m.predict(rng_keys[1], X_test, samples)
assert isinstance(y_mean, jnp.ndarray)
assert isinstance(y_sampled, jnp.ndarray)
assert_equal(y_mean.shape, X_test.squeeze().shape)
assert_equal(y_sampled.shape, (100, X_test.shape[0]))
assert_equal(y_sampled.shape, (100, X_test.shape[0]))


def test_fit_predict():
key1, key2 = get_keys()
X, y = get_dummy_data()
X_test = onp.linspace(X.min(), X.max(), 200)
m = sPM(model, model_priors)
m.fit(key1, X, y, num_warmup=100, num_samples=100)
y_mean, y_sampled = m.predict(key2, X_test)
assert_equal(y_mean.shape, X_test.squeeze().shape)
assert_equal(y_sampled.shape, (100, X_test.shape[0]))

0 comments on commit 64bbec2

Please sign in to comment.