Skip to content

Commit

Permalink
Fix: Non-jitted full-rank VI works
Browse files Browse the repository at this point in the history
Fix testing bug, add docstrings, and change softplus to exponential when
converting `chol_params` to `chol_factor` in `_unflatten_cholesky`.
  • Loading branch information
gil2rok committed Aug 19, 2024
1 parent 6b9c002 commit 4379a6d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 19 deletions.
84 changes: 68 additions & 16 deletions blackjax/vi/fullrank_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Callable, NamedTuple

import jax
import jax.flatten_util
import jax.numpy as jnp
import jax.scipy as jsp
from optax import GradientTransformation, OptState
Expand All @@ -32,12 +33,32 @@


class FRVIState(NamedTuple):
    """State of the full-rank VI algorithm.

    mu:
        Mean of the Gaussian approximation.
    chol_params:
        Flattened Cholesky factor of the Gaussian approximation, used to
        parameterize the full-rank covariance matrix. A vector of length
        d(d+1)/2 for a d-dimensional Gaussian, containing the d diagonal
        elements (in log space) followed by the strictly lower triangular
        elements in row-major order.
    opt_state:
        Optax optimizer state.

    """

    mu: ArrayTree
    chol_params: ArrayTree
    opt_state: OptState


class FRVIInfo(NamedTuple):
    """Extra information of the full-rank VI algorithm.

    elbo:
        ELBO of the approximation with respect to the target distribution.

    """

    elbo: float


Expand All @@ -47,10 +68,10 @@ def init(
*optimizer_args,
**optimizer_kwargs,
) -> FRVIState:
"""Initialize the full-rank VI state."""
"""Initialize the full-rank VI state with zero mean and identity covariance."""
mu = jax.tree.map(jnp.zeros_like, position)
dim = jax.flatten_util.ravel_pytree(mu)[0].shape[0]
chol_params, _ = jax.flatten_util.ravel_pytree(jnp.tril(jnp.eye(dim)))
chol_params = jnp.zeros(dim * (dim + 1) // 2)
opt_state = optimizer.init((mu, chol_params))
return FRVIState(mu, chol_params, opt_state)

Expand All @@ -63,7 +84,7 @@ def step(
num_samples: int = 5,
stl_estimator: bool = True,
) -> tuple[FRVIState, FRVIInfo]:
"""Approximate the target density using the full-rank approximation.
"""Approximate the target density using the full-rank Gaussian approximation
Parameters
----------
def kl_divergence_fn(parameters):
    """Monte Carlo estimate of E_q[log q(z) - log p(z)].

    Parameters
    ----------
    parameters
        Tuple ``(mu, chol_params)`` of the variational parameters being
        differentiated.

    Returns
    -------
    Mean of ``logq - logp`` over ``num_samples`` draws from the approximation.

    """
    mu, chol_params = parameters
    # Samples are drawn from the differentiable parameters so that gradients
    # flow through the sampling path.
    z = _sample(rng_key, mu, chol_params, num_samples)
    if stl_estimator:
        # Rebind the names used below: stopping gradients on a copy that is
        # never read again would leave the estimator unchanged.
        mu, chol_params = jax.tree.map(jax.lax.stop_gradient, (mu, chol_params))
    logq = jax.vmap(generate_fullrank_logdensity(mu, chol_params))(z)
    logp = jax.vmap(logdensity_fn)(z)
    return (logq - logp).mean()
Expand Down Expand Up @@ -147,30 +168,61 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int):
def _unflatten_cholesky(chol_params):
"""Construct the Cholesky factor from a flattened vector of Cholesky parameters.
Transforms a flattened vector representation of a lower triangular matrix
into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements
consisting of d diagonal elements followed by n-d off-diagonal elements in
row-major order, where d is the dimension of the matrix.
The diagonal elements are passed through a softplus function to ensure (numerically
stable) positivity, such that the resulting Cholesky factor is positive definite.
The Cholesky factor (L) is a lower triangular matrix with positive diagonal
elements used to parameterize the (full-rank) covariance matrix of the Gaussian
approximation as Sigma = LL^T.
This parameterization allows for (1) efficient sampling and log density evaluation,
and (2) ensuring the covariance matrix is symmetric and positive definite during
(unconconstrained) optimization.
Transforms a flattened vector representation of the Cholesky factor (`chol_params`)
into its proper lower triangular matrix form (`chol_factor`). It specifically
reshapes the input vector `chol_params` into a lower triangular matrix with zeros
above the diagonal and exponentiates the diagonal elements to ensure positivity.
This parameterization allows for unconstrained optimization while ensuring the
resulting covariance matrix Sigma = CC^T is symmetric and positive definite.
Parameters
----------
chol_params
Flattened Cholesky factor of the full-rank covariance matrix.
Returns
-------
chol_factor
Cholesky factor of the full-rank covariance matrix.
"""

n = chol_params.size
dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2
tril = jnp.zeros((dim, dim))
tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
diag = jax.nn.softplus(chol_params[:dim])
diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus?
chol_factor = tril + jnp.diag(diag)
return chol_factor


def _sample(rng_key, mu, chol_params, num_samples):
    """Draw samples from the full-rank Gaussian approximation.

    Parameters
    ----------
    rng_key
        Key for JAX's pseudo-random number generator.
    mu
        Mean of the Gaussian approximation.
    chol_params
        Flattened Cholesky factor of the Gaussian approximation.
    num_samples
        Number of samples to draw.

    Returns
    -------
    Samples drawn from the full-rank Gaussian approximation, with the same
    pytree structure as `mu` and a leading axis of length `num_samples`.

    """
    flat_mean, to_pytree = jax.flatten_util.ravel_pytree(mu)
    scale_tril = _unflatten_cholesky(chol_params)
    noise_shape = (num_samples, *flat_mean.shape)
    noise = jax.random.normal(rng_key, noise_shape)
    # Each row of standard-normal noise is scaled by the Cholesky factor and
    # shifted by the flattened mean.
    flat_draws = noise @ scale_tril.T + flat_mean
    return jax.vmap(to_pytree)(flat_draws)

Expand Down
6 changes: 3 additions & 3 deletions tests/vi/test_fullrank_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
class FullRankVITest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

@chex.variants(with_jit=True, without_jit=True)
@chex.variants(with_jit=True, without_jit=False)
def test_recover_posterior(self):
ground_truth = [
# loc, scale
Expand All @@ -38,7 +38,7 @@ def logdensity_fn(x):

rng_key = self.key
for i in range(num_steps):
subkey = jax.random.split(rng_key, i)
subkey = jax.random.fold_in(rng_key, i)
state, _ = self.variant(frvi.step)(subkey, state)

loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
Expand Down

0 comments on commit 4379a6d

Please sign in to comment.