From bb08e1f364cda6a4938cf00517cc6a96c324a3b8 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Mon, 19 Aug 2024 03:58:35 +0000 Subject: [PATCH] Doc: formatting --- blackjax/vi/fullrank_vi.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 009de960e..4ae17fcf6 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -39,12 +39,12 @@ class FRVIState(NamedTuple): Mean of the Gaussian approximation. chol_params: Flattened Cholesky factor of the Gaussian approximation, used to parameterize - the full-rank covariance matrix. A vector of length d(d+1)/2 for a - d-dimensional Gaussian, containing d diagonal elements (in log space) followed + the full-rank covariance matrix. A vector of length d(d+1)/2 for a + d-dimensional Gaussian, containing d diagonal elements (in log space) followed by lower triangular elements in row-major order. opt_state: Optax optimizer state. - + """ mu: ArrayTree @@ -54,11 +54,12 @@ class FRVIState(NamedTuple): class FRVIInfo(NamedTuple): """Extra information of the full-rank VI algorithm. - + elbo: ELBO of approximation wrt target distribution. """ + elbo: float @@ -168,42 +169,42 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): def _unflatten_cholesky(chol_params): """Construct the Cholesky factor from a flattened vector of Cholesky parameters. - The Cholesky factor (L) is a lower triangular matrix with positive diagonal - elements used to parameterize the (full-rank) covariance matrix of the Gaussian + The Cholesky factor (L) is a lower triangular matrix with positive diagonal + elements used to parameterize the (full-rank) covariance matrix of the Gaussian approximation as Sigma = LL^T. - + This parameterization allows for (1) efficient sampling and log density evaluation, - and (2) ensuring the covariance matrix is symmetric and positive definite during + and (2) ensuring the covariance matrix is symmetric and positive definite during (unconconstrained) optimization. - + Transforms a flattened vector representation of the Cholesky factor (`chol_params`) - into its proper lower triangular matrix form (`chol_factor`). It specifically - reshapes the input vector `chol_params` into a lower triangular matrix with zeros + into its proper lower triangular matrix form (`chol_factor`). It specifically + reshapes the input vector `chol_params` into a lower triangular matrix with zeros above the diagonal and exponentiates the diagonal elements to ensure positivity. Parameters ---------- chol_params Flattened Cholesky factor of the full-rank covariance matrix. - + Returns ------- chol_factor Cholesky factor of the full-rank covariance matrix. """ - + n = chol_params.size dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2 tril = jnp.zeros((dim, dim)) tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:]) - diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus? + diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus? chol_factor = tril + jnp.diag(diag) return chol_factor def _sample(rng_key, mu, chol_params, num_samples): """Sample from the full-rank Gaussian approximation of the target distribution. - + Parameters ---------- rng_key @@ -214,7 +215,7 @@ def _sample(rng_key, mu, chol_params, num_samples): Flattened Cholesky factor of the Gaussian approximation. num_samples Number of samples to draw. - + Returns ------- Samples drawn from the full-rank Gaussian approximation.