diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py
index d65fadf85..009de960e 100644
--- a/blackjax/vi/fullrank_vi.py
+++ b/blackjax/vi/fullrank_vi.py
@@ -14,6 +14,7 @@
 from typing import Callable, NamedTuple
 
 import jax
+import jax.flatten_util
 import jax.numpy as jnp
 import jax.scipy as jsp
 from optax import GradientTransformation, OptState
@@ -32,12 +33,32 @@
 
 
 class FRVIState(NamedTuple):
+    """State of the full-rank VI algorithm.
+
+    mu:
+        Mean of the Gaussian approximation.
+    chol_params:
+        Flattened Cholesky factor of the Gaussian approximation, used to parameterize
+        the full-rank covariance matrix. A vector of length d(d+1)/2 for a
+        d-dimensional Gaussian, containing d diagonal elements (in log space) followed
+        by lower triangular elements in row-major order.
+    opt_state:
+        Optax optimizer state.
+
+    """
+
     mu: ArrayTree
-    chol_params: ArrayTree  # flattened Cholesky factor
+    chol_params: ArrayTree
     opt_state: OptState
 
 
 class FRVIInfo(NamedTuple):
+    """Extra information of the full-rank VI algorithm.
+
+    elbo:
+        ELBO of approximation wrt target distribution.
+
+    """
     elbo: float
 
 
@@ -47,10 +68,10 @@ def init(
     *optimizer_args,
     **optimizer_kwargs,
 ) -> FRVIState:
-    """Initialize the full-rank VI state."""
+    """Initialize the full-rank VI state with zero mean and identity covariance."""
     mu = jax.tree.map(jnp.zeros_like, position)
     dim = jax.flatten_util.ravel_pytree(mu)[0].shape[0]
-    chol_params, _ = jax.flatten_util.ravel_pytree(jnp.tril(jnp.eye(dim)))
+    chol_params = jnp.zeros(dim * (dim + 1) // 2)
     opt_state = optimizer.init((mu, chol_params))
     return FRVIState(mu, chol_params, opt_state)
 
@@ -63,7 +84,7 @@ def step(
     num_samples: int = 5,
     stl_estimator: bool = True,
 ) -> tuple[FRVIState, FRVIInfo]:
-    """Approximate the target density using the full-rank approximation.
+    """Approximate the target density using the full-rank Gaussian approximation.
 
     Parameters
     ----------
@@ -92,7 +113,7 @@ def kl_divergence_fn(parameters):
         mu, chol_params = parameters
         z = _sample(rng_key, mu, chol_params, num_samples)
         if stl_estimator:
-            parameters = jax.tree_map(jax.lax.stop_gradient, (mu, chol_params))
+            parameters = jax.tree.map(jax.lax.stop_gradient, (mu, chol_params))
         logq = jax.vmap(generate_fullrank_logdensity(mu, chol_params))(z)
         logp = jax.vmap(logdensity_fn)(z)
         return (logq - logp).mean()
@@ -147,30 +168,61 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int):
 
 
 def _unflatten_cholesky(chol_params):
     """Construct the Cholesky factor from a flattened vector of Cholesky parameters.
 
-    Transforms a flattened vector representation of a lower triangular matrix
-    into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements
-    consisting of d diagonal elements followed by n-d off-diagonal elements in
-    row-major order, where d is the dimension of the matrix.
-
-    The diagonal elements are passed through a softplus function to ensure (numerically
-    stable) positivity, such that the resulting Cholesky factor is positive definite.
+    The Cholesky factor (L) is a lower triangular matrix with positive diagonal
+    elements used to parameterize the (full-rank) covariance matrix of the Gaussian
+    approximation as Sigma = LL^T.
+
+    This parameterization allows for (1) efficient sampling and log density evaluation,
+    and (2) ensuring the covariance matrix is symmetric and positive definite during
+    (unconstrained) optimization.
+
+    Transforms a flattened vector representation of the Cholesky factor (`chol_params`)
+    into its proper lower triangular matrix form (`chol_factor`). It specifically
+    reshapes the input vector `chol_params` into a lower triangular matrix with zeros
+    above the diagonal and exponentiates the diagonal elements to ensure positivity.
 
-    This parameterization allows for unconstrained optimization while ensuring the
-    resulting covariance matrix Sigma = CC^T is symmetric and positive definite.
+    Parameters
+    ----------
+    chol_params
+        Flattened Cholesky factor of the full-rank covariance matrix.
+
+    Returns
+    -------
+    chol_factor
+        Cholesky factor of the full-rank covariance matrix.
     """
+    n = chol_params.size
     dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2
     tril = jnp.zeros((dim, dim))
     tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
-    diag = jax.nn.softplus(chol_params[:dim])
+    diag = jnp.exp(chol_params[:dim])  # TODO: replace with softplus?
     chol_factor = tril + jnp.diag(diag)
     return chol_factor
 
 
 def _sample(rng_key, mu, chol_params, num_samples):
+    """Sample from the full-rank Gaussian approximation of the target distribution.
+
+    Parameters
+    ----------
+    rng_key
+        Key for JAX's pseudo-random number generator.
+    mu
+        Mean of the Gaussian approximation.
+    chol_params
+        Flattened Cholesky factor of the Gaussian approximation.
+    num_samples
+        Number of samples to draw.
+
+    Returns
+    -------
+    Samples drawn from the full-rank Gaussian approximation.
+
+    """
     mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
     chol_factor = _unflatten_cholesky(chol_params)
-    eps = jax.random.normal(rng_key, (num_samples, mu_flatten.size))
+    eps = jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape)
     flatten_sample = eps @ chol_factor.T + mu_flatten
     return jax.vmap(unravel_fn)(flatten_sample)
diff --git a/tests/vi/test_fullrank_vi.py b/tests/vi/test_fullrank_vi.py
index a8b35d131..411e0c8d0 100644
--- a/tests/vi/test_fullrank_vi.py
+++ b/tests/vi/test_fullrank_vi.py
@@ -11,9 +11,9 @@ class FullRankVITest(chex.TestCase):
 
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(42)
+        self.key = jax.random.key(42)
 
-    @chex.variants(with_jit=True, without_jit=True)
+    @chex.variants(with_jit=True, without_jit=False)
     def test_recover_posterior(self):
         ground_truth = [
             # loc, scale
@@ -38,7 +38,7 @@ def logdensity_fn(x):
 
         rng_key = self.key
         for i in range(num_steps):
-            subkey = jax.random.split(rng_key, i)
+            subkey = jax.random.fold_in(rng_key, i)
             state, _ = self.variant(frvi.step)(subkey, state)
 
         loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
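Reviewer note (not part of the patch): a minimal standalone sketch of the chol_params layout the new docstrings describe, for d = 2 with arbitrary example values. The first d entries are the log-diagonal of L and the remaining d(d-1)/2 entries fill the strict lower triangle in row-major order; the arithmetic mirrors _unflatten_cholesky, and the affine transform mirrors _sample, so cov(samples) is approximately Sigma = L L^T.

import jax
import jax.numpy as jnp

# d = 2 -> d(d+1)/2 = 3 parameters: [log L00, log L11, L10]
chol_params = jnp.array([0.0, jnp.log(2.0), 0.5])

n = chol_params.size
dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2  # invert n = d(d+1)/2
tril = jnp.zeros((dim, dim))
tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
chol_factor = tril + jnp.diag(jnp.exp(chol_params[:dim]))
# chol_factor == [[1.0, 0.0], [0.5, 2.0]]

# Affine transform of standard normal draws, as in _sample.
eps = jax.random.normal(jax.random.key(0), (10_000, dim))
samples = eps @ chol_factor.T
print(jnp.cov(samples.T))  # approx Sigma = [[1.0, 0.5], [0.5, 4.25]]

The exp/log parameterization keeps the diagonal of L positive under unconstrained gradient updates, which is also why init can use a zero vector for chol_params: exp(0) = 1 gives an identity Cholesky factor and hence an identity covariance.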