Cholesky decomposition for ExactGP class #115

Open · wants to merge 1 commit into main
33 changes: 24 additions & 9 deletions gpax/models/gp.py
@@ -251,7 +251,7 @@ def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
         return self.mcmc.get_samples(group_by_chain=chain_dim)
 
     def get_mvn_posterior(
-        self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
+        self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, use_cholesky: bool = False, **kwargs: float
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """
         Returns parameters (mean and cov) of multivariate normal posterior
@@ -267,13 +267,24 @@ def get_mvn_posterior(
         k_pp = self.kernel(X_new, X_new, params, noise_p, **kwargs)
         k_pX = self.kernel(X_new, self.X_train, params, jitter=0.0)
         k_XX = self.kernel(self.X_train, self.X_train, params, noise, **kwargs)
-        # compute the predictive covariance and mean
-        K_xx_inv = jnp.linalg.inv(k_XX)
-        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
-        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
+
+        # Compute the predictive covariance and mean.
+        # Since k_XX is symmetric positive-definite, the Cholesky decomposition
+        # is more efficient and numerically stable than explicit matrix inversion.
+        if use_cholesky:
+            K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)
+            cov = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))
+            mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, y_residual))
+        else:
+            K_xx_inv = jnp.linalg.inv(k_XX)
+            cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
+            mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
 
         if self.mean_fn is not None:
             args = [X_new, params] if self.mean_fn_prior else [X_new]
             mean += self.mean_fn(*args).squeeze()
 
         return mean, cov

@@ -283,11 +294,12 @@ def _predict(
         params: Dict[str, jnp.ndarray],
         n: int,
         noiseless: bool = False,
+        use_cholesky: bool = False,
         **kwargs: float
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """Prediction with a single sample of GP parameters"""
         # Get the predictive mean and covariance
-        y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, **kwargs)
+        y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, use_cholesky, **kwargs)
         # draw samples from the posterior predictive for a given set of parameters
         y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
         return y_mean, y_sampled
@@ -304,10 +316,11 @@ def _predict_in_batches(
         predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
         noiseless: bool = False,
         device: Type[jaxlib.xla_extension.Device] = None,
+        use_cholesky: bool = False,
         **kwargs: float
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         if predict_fn is None:
-            predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs)
+            predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, use_cholesky, **kwargs)
 
         def predict_batch(Xi):
             out1, out2 = predict_fn(Xi)
@@ -333,6 +346,7 @@ def predict_in_batches(
         predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
         noiseless: bool = False,
         device: Type[jaxlib.xla_extension.Device] = None,
+        use_cholesky: bool = False,
         **kwargs: float
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """
@@ -342,7 +356,7 @@ def predict_in_batches(
         to avoid a memory overflow
         """
         y_pred, y_sampled = self._predict_in_batches(
-            rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, **kwargs
+            rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, use_cholesky, **kwargs
         )
         y_pred = jnp.concatenate(y_pred, 0)
         y_sampled = jnp.concatenate(y_sampled, -1)
@@ -357,6 +371,7 @@ def predict(
         filter_nans: bool = False,
         noiseless: bool = False,
         device: Type[jaxlib.xla_extension.Device] = None,
+        use_cholesky: bool = False,
         **kwargs: float
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """
@@ -391,7 +406,7 @@ def predict(
             samples = jax.device_put(samples, device)
         num_samples = len(next(iter(samples.values())))
         vmap_args = (jra.split(rng_key, num_samples), samples)
-        predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs))
+        predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, use_cholesky, **kwargs))
         y_means, y_sampled = predictive(vmap_args)
         if filter_nans:
             y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
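
Note on the change above: both branches compute the standard GP predictive equations, mean = k_pX · K_XX⁻¹ · y_residual and cov = k_pp − k_pX · K_XX⁻¹ · k_pXᵀ, so the Cholesky branch only replaces the explicit inverse with a factorization followed by triangular solves. A minimal standalone sketch (not part of this diff; the matrix and vector are made up for illustration) showing the two solve paths agree:

# Standalone sketch (not part of the PR): check that cho_factor/cho_solve
# reproduces the explicit-inverse solution for a small SPD system.
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[1.0, 0.2], [0.3, 1.5]])
K = A @ A.T + 1e-6 * jnp.eye(2)  # symmetric positive-definite by construction
b = jnp.array([0.5, -1.0])

x_chol = cho_solve(cho_factor(K), b)  # factor K = L L^T, then two triangular solves
x_inv = jnp.linalg.inv(K) @ b         # explicit inverse (what the else-branch does)

assert jnp.allclose(x_chol, x_inv, atol=1e-5)

The Cholesky path also avoids forming K_XX⁻¹ at all, which is where the stability gain claimed in the code comment comes from.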
15 changes: 15 additions & 0 deletions tests/test_gp.py
@@ -169,6 +169,21 @@ def test_get_mvn_posterior_noiseless():
     assert_array_equal(mean1, mean2)
     assert onp.count_nonzero(cov1 - cov2) > 0
 
+def test_get_mvn_posterior_cholesky():
+    X, y = get_dummy_data(unsqueeze=True)
+    X_test, _ = get_dummy_data(unsqueeze=True)
+    params = {"k_length": jnp.array([1.0]),
+              "k_scale": jnp.array(1.0),
+              "noise": jnp.array(0.1)}
+    m = ExactGP(1, 'RBF')
+    m.X_train = X
+    m.y_train = y
+    mean, cov = m.get_mvn_posterior(X_test, params, use_cholesky=True)
+    assert isinstance(mean, jnp.ndarray)
+    assert isinstance(cov, jnp.ndarray)
+    assert_equal(mean.shape, (X_test.shape[0],))
+    assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))
+
 
 def test_single_sample_prediction():
     rng_key = get_keys()[0]
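
For anyone trying the branch locally, a hypothetical end-to-end usage sketch: the data, keys, and fit call below are illustrative and assume gpax's usual fit/predict workflow rather than anything specified in this PR.

# Hypothetical usage of the new flag (toy data; not from the PR)
import jax.numpy as jnp
import gpax

X = jnp.linspace(-1.0, 1.0, 20)[:, None]
y = jnp.sin(3.0 * X.squeeze())
X_test = jnp.linspace(-1.5, 1.5, 50)[:, None]

key_fit, key_predict = gpax.utils.get_keys()
model = gpax.ExactGP(1, 'RBF')
model.fit(key_fit, X, y)  # HMC sampling of kernel hyperparameters

# use_cholesky=True routes get_mvn_posterior through cho_factor/cho_solve
# instead of jnp.linalg.inv
y_pred, y_sampled = model.predict(key_predict, X_test, use_cholesky=True)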