Diffusive Gibbs Sampler #744

Open
jcopo opened this issue Oct 1, 2024 · 7 comments

Comments

@jcopo

jcopo commented Oct 1, 2024

Presentation of the new sampler

DiGS is an auxiliary-variable MCMC method where the auxiliary variable $\tilde{x}$ is a noisy version of the original variable $x$. DiGS enhances mixing and helps escape local modes by alternately sampling from the distributions $p(\tilde{x}|x)$, which introduces noise via Gaussian convolution, and $p(x|\tilde{x})$, which denoises the sample back to the original space using a score-based update (e.g., a Langevin diffusion).
https://arxiv.org/abs/2402.03008
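
For concreteness, here is a minimal sketch of one DiGS sweep in JAX, assuming a target log-density `logdensity_fn`, contraction `alpha`, noise scale `sigma`, and a fixed number of unadjusted Langevin denoising steps (all names are illustrative; a faithful implementation would also include the Metropolis-Hastings correction and the initialization scheme from the paper):

```python
import jax
import jax.numpy as jnp

def digs_sweep(rng_key, x, logdensity_fn, alpha, sigma, step_size, n_denoise):
    """One DiGS sweep: noise x, then denoise with a few Langevin steps."""
    key_noise, key_denoise = jax.random.split(rng_key)

    # Noising step: sample the auxiliary variable x_tilde ~ N(alpha * x, sigma^2 I).
    x_tilde = alpha * x + sigma * jax.random.normal(key_noise, x.shape)

    # Score of the conditional p(x | x_tilde) ∝ p(x) N(x_tilde; alpha * x, sigma^2 I).
    def cond_score(x_):
        return jax.grad(logdensity_fn)(x_) + alpha * (x_tilde - alpha * x_) / sigma**2

    # Denoising step: a few unadjusted Langevin (ULA) updates on p(x | x_tilde).
    def ula_step(x_, key):
        noise = jax.random.normal(key, x_.shape)
        return x_ + step_size * cond_score(x_) + jnp.sqrt(2.0 * step_size) * noise, None

    keys = jax.random.split(key_denoise, n_denoise)
    x_new, _ = jax.lax.scan(ula_step, x, keys)
    return x_new
```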

I had very good results with it in small to medium dimensions. It really helps with escaping local modes. A very powerful usage is as a proposal in an SMC-based procedure, where it helps move samples back into depleted zones. In high dimensions, the acceptance ratio of the MH step becomes the tricky part.

If you think it is a sensible addition to Blackjax I'll be happy to contribute.

How does it compare to other algorithms in blackjax?

The number of denoising steps is flexible, so it can be computationally efficient. The algorithm is conceptually quite simple but applicable to a wide class of problems.

Where does it fit in blackjax

As an MCMC kernel per se or as an SMC proposal.

Are you willing to open a PR?

Yes - I have a version that I used for my research and would be happy to contribute it to Blackjax

@AdrienCorenflos
Contributor

AdrienCorenflos commented Oct 1, 2024

Can you explain how this algorithm is different from the auxiliary perspective of MALA in https://rss.onlinelibrary.wiley.com/doi/full/10.1111/rssb.12269?

Edit: Gaussian convolution and then Langevin "denoising" is exactly aMALA in my books, so where's the difference?

@jcopo
Author

jcopo commented Oct 1, 2024

Maybe, given your comment, I should underline that I have no personal interest in the paper. I was not familiar with the reference you provided, but the method does look similar.

At first glance, aMALA doesn't seem to have the multilevel noise schedule of Diffusive Gibbs, and the initialization of the denoising step (eq. 14 of DiGS) seems to be different. Are the contraction idea and the corresponding variance of eq. 10 also in aMALA? I had a look at your code implementation in marginal_latent_gaussian.py, but it wasn't obvious.

@AdrienCorenflos
Contributor

The marginal latent Gaussian sampler is the counterpart for Gaussian priors, but I really think it's related: the auxiliary target is the same one, and then it looks like it's just a bunch of MALA steps with increasing step size. This is not unprecedented in the literature (although people typically take step sizes coming from Chebyshev polynomials, for which there is theory).

I am not against having underperforming or academically not-so-novel samplers in the library, mind you; I'm mostly thinking we may want to think carefully about the components and implement those, rather than the special instance that DiGS offers.

@jcopo
Author

jcopo commented Oct 2, 2024

it's just a bunch of MALA steps with increasing step size

I don't see how this is true? The schedule modifies the proximal version of the score function (eq. 12) and not directly the step size.

Putting aside academic novelty and which paper gets credited, I think this is a simple yet effective sampler.
I'm happy to think about the components and how to implement them, but it's not clear to me what algorithm DiGS should be a special case of.
If you have references/ideas, I'd be interested in having a look.

@AdrienCorenflos
Contributor

But Eq. (12) is immediately the gradient of the conditional proximal density, exactly as would happen in auxiliary MALA with a different choice of decomposition in terms of $\alpha, \sigma$; I'm not sure what you mean.
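Writing it out (my reading of the construction, with $\tilde{x} \mid x \sim \mathcal{N}(\alpha x, \sigma^2 I)$), the conditional score the denoising step uses is

$$\nabla_x \log p(x \mid \tilde{x}) = \nabla_x \log p(x) + \frac{\alpha\,(\tilde{x} - \alpha x)}{\sigma^2},$$

i.e. the gradient of the target plus a Gaussian proximal term, which is exactly what the conditional gradient in auxiliary MALA gives, up to the choice of $(\alpha, \sigma)$.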

From what I understand (I may well be missing something), the algorithm to sample from $p(x, u) = p(x)\, p(u \mid x)$ is the following.
Given the current state $X^*$,

  1. Sample an auxiliary variable $U \sim p(u \mid X^*)$
  2. Form the conditional $p(x \mid U) \propto p(x, U)$
  3. Jitter $X^*$ (or not, depending on an MH accept)
  4. Apply MALA a bunch of times to $p(x \mid U)$ with increasing step sizes, which corresponds to a warm-up with a non-adaptive schedule (see the sketch below)
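
A rough sketch of that reading (illustrative, not a proposed API; it assumes blackjax's `blackjax.mala` interface and a Gaussian auxiliary $p(u \mid x) = \mathcal{N}(u; \alpha x, \sigma^2 I)$; the jitter/initialization of step 3 is omitted):

```python
import jax
import jax.numpy as jnp
import blackjax

def aux_mala_sweep(rng_key, x, logdensity_fn, alpha, sigma, step_sizes):
    key_u, key_mala = jax.random.split(rng_key)

    # 1. Sample the auxiliary variable U ~ p(u | x) = N(alpha * x, sigma^2 I).
    u = alpha * x + sigma * jax.random.normal(key_u, x.shape)

    # 2. Conditional log-density p(x | U) ∝ p(x) N(U; alpha * x, sigma^2 I).
    def cond_logdensity(x_):
        return logdensity_fn(x_) - 0.5 * jnp.sum((u - alpha * x_) ** 2) / sigma**2

    # 4. A bunch of MALA steps with an increasing, non-adaptive step-size schedule.
    state = blackjax.mala(cond_logdensity, step_sizes[0]).init(x)
    for key, step_size in zip(jax.random.split(key_mala, len(step_sizes)), step_sizes):
        state, _ = blackjax.mala(cond_logdensity, step_size).step(key, state)
    return state.position
```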

So, in some sense, what I can see here is that maybe we want to disconnect the step size and the scale in our MALA algo to allow for different balancing in auxiliary schemes, but that's kind of it? Also I'm really not sure the choice of balancing they have is the best one. All in all, I'd support a small refactoring to allow for more flexible parameterization of MALA (or maybe we already have this) and some Gaussian proximal auxiliary utility, then add the DiGS sampler as an example, not as a core library one.
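
To make the "disconnect the step size and the scale" point concrete, a minimal sketch of what I mean (not an existing blackjax API; names are illustrative): standard MALA ties the drift to $h/2$ and the noise to $\sqrt{h}$ for a single step size $h$, whereas here the two scales are separate parameters.

```python
import jax
import jax.numpy as jnp

def decoupled_mala_proposal(rng_key, x, logdensity_fn, tau, s):
    # Drift scale `tau` and noise scale `s` are decoupled (standard MALA: tau = h / 2, s = sqrt(h)).
    mean = x + tau * jax.grad(logdensity_fn)(x)
    return mean + s * jax.random.normal(rng_key, x.shape)

def decoupled_mala_proposal_logdensity(y, x, logdensity_fn, tau, s):
    # Proposal log-density q(y | x), up to an additive constant that cancels in the MH ratio.
    mean = x + tau * jax.grad(logdensity_fn)(x)
    return -0.5 * jnp.sum((y - mean) ** 2) / s**2
```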

@jcopo
Author

jcopo commented Oct 2, 2024

Ok, yes, I get you. It wasn't clear to me which part of the algorithm you were referring to.

Also I'm really not sure the choice of balancing they have is the best one.

I agree, but the idea of dilation/contraction of the space with $\alpha$ is interesting, especially in a multimodal setting.

Is there already something in place for dealing with auxiliary-variable samplers? On top of what you propose, I think this could be a nice addition.

@AdrienCorenflos
Contributor

AdrienCorenflos commented Oct 2, 2024 via email
