Skip to content

Commit

Permalink
Doc: formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gil2rok committed Aug 19, 2024
1 parent 4379a6d commit bb08e1f
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions blackjax/vi/fullrank_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ class FRVIState(NamedTuple):
Mean of the Gaussian approximation.
chol_params:
Flattened Cholesky factor of the Gaussian approximation, used to parameterize
the full-rank covariance matrix. A vector of length d(d+1)/2 for a
d-dimensional Gaussian, containing d diagonal elements (in log space) followed
the full-rank covariance matrix. A vector of length d(d+1)/2 for a
d-dimensional Gaussian, containing d diagonal elements (in log space) followed
by lower triangular elements in row-major order.
opt_state:
Optax optimizer state.
"""

mu: ArrayTree
Expand All @@ -54,11 +54,12 @@ class FRVIState(NamedTuple):

class FRVIInfo(NamedTuple):
"""Extra information of the full-rank VI algorithm.
elbo:
ELBO of approximation wrt target distribution.
"""

elbo: float


Expand Down Expand Up @@ -168,42 +169,42 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int):
def _unflatten_cholesky(chol_params):
"""Construct the Cholesky factor from a flattened vector of Cholesky parameters.
The Cholesky factor (L) is a lower triangular matrix with positive diagonal
elements used to parameterize the (full-rank) covariance matrix of the Gaussian
The Cholesky factor (L) is a lower triangular matrix with positive diagonal
elements used to parameterize the (full-rank) covariance matrix of the Gaussian
approximation as Sigma = LL^T.
This parameterization allows for (1) efficient sampling and log density evaluation,
and (2) ensuring the covariance matrix is symmetric and positive definite during
and (2) ensuring the covariance matrix is symmetric and positive definite during
(unconconstrained) optimization.
Transforms a flattened vector representation of the Cholesky factor (`chol_params`)
into its proper lower triangular matrix form (`chol_factor`). It specifically
reshapes the input vector `chol_params` into a lower triangular matrix with zeros
into its proper lower triangular matrix form (`chol_factor`). It specifically
reshapes the input vector `chol_params` into a lower triangular matrix with zeros
above the diagonal and exponentiates the diagonal elements to ensure positivity.
Parameters
----------
chol_params
Flattened Cholesky factor of the full-rank covariance matrix.
Returns
-------
chol_factor
Cholesky factor of the full-rank covariance matrix.
"""

n = chol_params.size
dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2
tril = jnp.zeros((dim, dim))
tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus?
diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus?
chol_factor = tril + jnp.diag(diag)
return chol_factor


def _sample(rng_key, mu, chol_params, num_samples):
"""Sample from the full-rank Gaussian approximation of the target distribution.
Parameters
----------
rng_key
Expand All @@ -214,7 +215,7 @@ def _sample(rng_key, mu, chol_params, num_samples):
Flattened Cholesky factor of the Gaussian approximation.
num_samples
Number of samples to draw.
Returns
-------
Samples drawn from the full-rank Gaussian approximation.
Expand Down

0 comments on commit bb08e1f

Please sign in to comment.